```python
from transformers import TrainingArguments, Trainer

# Derive a short run name from the checkpoint id. Using [-1] keeps the
# "org/model" case identical to split("/")[1] while also tolerating a
# plain model name with no "/" (where [1] would raise IndexError).
model_name = checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,  # mixed-precision training; requires a CUDA-capable GPU
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,  # effective train batch = 32 * 2 per device
    save_total_limit=3,  # keep only the 3 most recent checkpoints on disk
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",  # must match evaluation_strategy for load_best_model_at_end
    save_steps=50,
    logging_steps=50,
    # NOTE(review): presumably the data collator needs dataset columns that the
    # Trainer would otherwise drop (e.g. raw images) — confirm against the collator.
    remove_unused_columns=False,
    push_to_hub=True,  # uploads checkpoints to the Hugging Face Hub during training
    label_names=["labels"],
    load_best_model_at_end=True,  # restore best eval-loss checkpoint after training
)
```

Then pass the training arguments, along with the datasets and the model, to 🤗 Trainer.