Specify the number of labels along with the label mappings:
```python
from transformers import ViltForQuestionAnswering

model = ViltForQuestionAnswering.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
```
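The two label mappings are inverse dictionaries: one maps answer strings to class indices, and the other maps indices back to strings. The sketch below shows one way to build them and the round-trip property they should satisfy. The answer list is hypothetical; in practice the answers are collected from the dataset.

```python
# Hypothetical answer vocabulary; a real one is collected from the dataset's answers.
labels = ["yes", "no", "2", "red"]

# label2id maps each answer string to a class index; id2label is its inverse.
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for label, idx in label2id.items()}

# The classification head needs one output per answer, hence num_labels=len(id2label).
num_labels = len(id2label)

# Round-trip check: every answer maps to an index and back to itself.
assert all(id2label[label2id[label]] == label for label in labels)
print(num_labels)  # 4
```

If the answer list contained duplicates, the two dictionaries would no longer be exact inverses. Deduplicate the answers before building the mappings.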
At this point, only three steps remain:

Define your training hyperparameters in [TrainingArguments]:
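```python
from transformers import TrainingArguments

repo_id = "MariaK/vilt_finetuned_200"

training_args = TrainingArguments(
    output_dir=repo_id,            # local output directory; the Hub repo name is derived from it by default
    per_device_train_batch_size=4,
    num_train_epochs=20,
    save_steps=200,                # save a checkpoint every 200 optimization steps
    logging_steps=50,              # log the training loss every 50 steps
    learning_rate=5e-5,
    save_total_limit=2,            # keep at most two checkpoints, deleting older ones
    remove_unused_columns=False,   # keep dataset columns not named in the model's forward() signature
    push_to_hub=True,              # upload to the Hugging Face Hub (requires being logged in)
)
```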
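You can work out how `save_steps`, `logging_steps`, and `save_total_limit` interact with a little arithmetic. The helper below is an illustrative approximation, not part of Transformers. It assumes no gradient accumulation, keeps the last partial batch of each epoch, and ignores extras such as `load_best_model_at_end`, which is not set here. The 200-example training set and the single device are hypothetical.

```python
import math


def training_schedule(num_examples, per_device_batch_size, num_devices,
                      num_epochs, save_steps, logging_steps, save_total_limit):
    """Approximate step counts for a Trainer run (illustrative only).

    Assumes no gradient accumulation and that the last partial batch of
    each epoch is kept, so one epoch takes ceil(N / global_batch) steps.
    """
    global_batch = per_device_batch_size * num_devices
    steps_per_epoch = math.ceil(num_examples / global_batch)
    total_steps = steps_per_epoch * num_epochs
    # A checkpoint is written every `save_steps` optimization steps.
    checkpoints = list(range(save_steps, total_steps + 1, save_steps))
    # save_total_limit keeps only the most recent checkpoints on disk.
    kept = checkpoints[-save_total_limit:]
    # One loss log entry every `logging_steps` steps.
    num_logs = total_steps // logging_steps
    return steps_per_epoch, total_steps, checkpoints, kept, num_logs


# Hypothetical: 200 training examples on a single device.
print(training_schedule(200, 4, 1, 20, 200, 50, 2))
# (50, 1000, [200, 400, 600, 800, 1000], [800, 1000], 20)
```

Under these assumptions, the run takes 1,000 steps and writes five checkpoints, but only the last two remain on disk.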
Pass the training arguments to [Trainer] along with the model, dataset, processor, and data collator.
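The data collator is what turns a list of preprocessed examples into a single batch for the model. As a framework-free illustration of the idea (not the collator this guide uses), the sketch below groups per-example dicts into per-field lists. Library collators such as `transformers.DefaultDataCollator` also convert the grouped values to tensors. The field names and values are hypothetical.

```python
from typing import Any


def collate(features: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Group a list of per-example dicts into one dict of per-field lists.

    Simplified illustration: real collators also convert values to tensors.
    """
    batch: dict[str, list[Any]] = {}
    for feature in features:
        for key, value in feature.items():
            batch.setdefault(key, []).append(value)
    return batch


# Hypothetical preprocessed examples with made-up values.
features = [
    {"input_ids": [101, 2054, 102], "labels": [0.0, 1.0]},
    {"input_ids": [101, 2129, 102], "labels": [1.0, 0.0]},
]
batch = collate(features)
print(batch["labels"])  # [[0.0, 1.0], [1.0, 0.0]]
```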