Specify the number of labels along with the label mappings:
```py
from transformers import ViltForQuestionAnswering

model = ViltForQuestionAnswering.from_pretrained(
    model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id
)
```
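If you have not already built the label mappings, they can be derived from the set of answer strings in your dataset. A minimal sketch, where `unique_labels` is an assumed placeholder for the unique answers collected earlier in the guide:

```py
# `unique_labels` is a hypothetical name for the unique answer strings;
# the values below are placeholders for illustration only.
unique_labels = ["yes", "no", "2", "cat", "down"]

id2label = {idx: label for idx, label in enumerate(unique_labels)}
label2id = {label: idx for idx, label in id2label.items()}
```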
At this point, only three steps remain:
1. Define your training hyperparameters in [TrainingArguments]:
```py
from transformers import TrainingArguments

repo_id = "MariaK/vilt_finetuned_200"

training_args = TrainingArguments(
    output_dir=repo_id,
    per_device_train_batch_size=4,
    num_train_epochs=20,
    save_steps=200,
    logging_steps=50,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=True,
)
```
2. Pass the training arguments to [Trainer] along with the model, dataset, processor, and data collator.
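A minimal sketch of this step, assuming the preprocessed dataset and the data collator defined earlier in the guide are named `processed_dataset` and `data_collator` (those names are assumptions, not confirmed by this excerpt):

```py
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,  # assumed name from earlier preprocessing
    tokenizer=processor,  # the ViLT processor handles both image and text inputs
)

# Call train() to finetune the model; with push_to_hub=True, checkpoints
# are uploaded to the Hub repository named in output_dir.
trainer.train()
```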