Specify the number of labels along with the label mappings:

```py
from transformers import ViltForQuestionAnswering

model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
```

At this point, only three steps remain:

1. Define your training hyperparameters in [`TrainingArguments`]:

```py
from transformers import TrainingArguments

repo_id = "MariaK/vilt_finetuned_200"

training_args = TrainingArguments(
    output_dir=repo_id,
    per_device_train_batch_size=4,
    num_train_epochs=20,
    save_steps=200,
    logging_steps=50,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=True,
)
```

2. Pass the training arguments to [`Trainer`] along with the model, dataset, processor, and data collator.
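The step above can be sketched as follows. This is a minimal illustration, assuming the `processed_dataset` and `data_collator` objects were created in earlier steps of the guide (the exact variable names may differ in your setup):

```py
from transformers import Trainer

# Assumes `model`, `training_args`, `processor`, `data_collator`, and
# `processed_dataset` are defined as in the preceding steps of this guide.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    tokenizer=processor,  # saved alongside the model so it can be reloaded later
)
```

Once the `Trainer` is constructed, calling `trainer.train()` starts fine-tuning, and `trainer.push_to_hub()` uploads the final checkpoint since `push_to_hub=True` was set in the training arguments.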