Specify the number of labels along with the label mappings:

from transformers import ViltForQuestionAnswering
model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
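
The `id2label` and `label2id` dictionaries passed above are a pair of inverse mappings between answer strings and integer class ids. As a minimal sketch of how they could be built, assuming a flat list of answer strings (the `answers` list and the `build_label_maps` helper are hypothetical and exist only for illustration):

```python
def build_label_maps(answers):
    """Build inverse label mappings from a list of answer strings."""
    # Deduplicate and sort so the id assignment is reproducible across runs.
    unique_answers = sorted(set(answers))
    label2id = {label: idx for idx, label in enumerate(unique_answers)}
    id2label = {idx: label for label, idx in label2id.items()}
    return id2label, label2id


# Hypothetical answer vocabulary; in practice it would come from the dataset.
answers = ["yes", "no", "2", "red", "yes"]
id2label, label2id = build_label_maps(answers)
num_labels = len(id2label)
```

Because the two dictionaries are exact inverses, `len(id2label)` is the size of the classification head the model needs.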

At this point, only three steps remain:

Define your training hyperparameters in [TrainingArguments]:

from transformers import TrainingArguments
repo_id = "MariaK/vilt_finetuned_200"
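training_args = TrainingArguments(
    output_dir=repo_id,              # checkpoints are written here
    per_device_train_batch_size=4,
    num_train_epochs=20,
    save_steps=200,                  # save a checkpoint every 200 steps
    logging_steps=50,                # log metrics every 50 steps
    learning_rate=5e-5,
    save_total_limit=2,              # keep only the 2 most recent checkpoints
    remove_unused_columns=False,     # keep columns the model's forward() does not name
    push_to_hub=True,                # upload to the Hub (requires being logged in)
)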
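
To get a feel for what these values imply, the following rough sketch estimates how many optimizer steps, checkpoints, and log events a run would produce. It assumes a single device, no gradient accumulation, and a hypothetical training set of 200 examples; `training_schedule` is an illustrative helper, not part of the library.

```python
import math


def training_schedule(num_examples, batch_size, epochs, save_steps, logging_steps):
    """Estimate step counts for a single-device run without gradient accumulation."""
    # The last, possibly smaller, batch of each epoch still counts as one step.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # A step-based save or log fires every time the step count hits a multiple.
        "checkpoints_written": total_steps // save_steps,
        "log_events": total_steps // logging_steps,
    }


# Hypothetical dataset size of 200 examples, with the hyperparameters above.
schedule = training_schedule(
    num_examples=200, batch_size=4, epochs=20, save_steps=200, logging_steps=50
)
print(schedule)
```

Under these assumptions the run takes 1000 steps and writes 5 checkpoints, of which `save_total_limit=2` would keep only the two most recent on disk.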

Pass the training arguments to [Trainer] along with the model, dataset, processor, and data collator.
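
As a hedged sketch of that step, the helper below wires the pieces into a `Trainer`. The `build_trainer` function and its parameter names are hypothetical; the model, training arguments, processed dataset, processor, and data collator are assumed to have been created earlier. Since the keyword that accepts the processor differs between `transformers` releases (`processing_class` in newer ones, `tokenizer` in older ones), the sketch checks the `Trainer` signature rather than hard-coding one.

```python
import inspect


def build_trainer(model, training_args, train_dataset, processor, data_collator):
    """Assemble a Trainer from already-prepared components (illustrative helper)."""
    # Imported lazily so this sketch can be loaded without transformers installed.
    from transformers import Trainer

    # Pick whichever keyword this transformers version uses for the processor.
    params = inspect.signature(Trainer.__init__).parameters
    processor_kwarg = "processing_class" if "processing_class" in params else "tokenizer"

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        **{processor_kwarg: processor},
    )
```

With the objects in hand, `trainer = build_trainer(model, training_args, train_dataset, processor, data_collator)` followed by `trainer.train()` would start fine-tuning.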