processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label.ids', 'label.weights'])
processed_dataset
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'],
    num_rows: 200
})
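To make the effect of `batched=True` and `remove_columns` more concrete, here is a minimal pure-Python sketch. It is not how the `datasets` library is implemented: `map_batched` and `toy_preprocess` are hypothetical names used only for illustration, and the sketch treats the whole dataset as a single batch, whereas the real `map` works through the dataset in chunks.

```python
def map_batched(columns, fn, remove_columns):
    """Toy stand-in for a batched map: call fn once on the column dict,
    merge in the columns it returns, then drop the listed columns."""
    new_columns = fn(columns)
    merged = {**columns, **new_columns}
    return {name: values for name, values in merged.items()
            if name not in remove_columns}


def toy_preprocess(batch):
    # Hypothetical preprocessing step: in batched mode the function sees
    # a dict of lists (one list per column) and returns a dict of lists.
    return {"question_length": [len(q.split()) for q in batch["question"]]}


columns = {
    "question": ["What is this?", "How many cats are there?"],
    "image_id": [1, 2],
}
processed = map_batched(columns, toy_preprocess,
                        remove_columns=["question", "image_id"])
print(processed)
```

Only the columns produced by the preprocessing function survive, which mirrors why the processed dataset above lists model inputs and `labels` but none of the raw annotation columns.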
As a final step, create a batch of examples using DefaultDataCollator:
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
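Conceptually, a collator takes a list of per-example dictionaries and groups the values under each key into one batch. The sketch below illustrates that idea with plain Python lists. It is a simplified illustration, not the library's implementation: `collate_examples` is a hypothetical helper, the feature values are made up, and the real collator additionally converts each batched field into a framework tensor.

```python
def collate_examples(examples):
    """Group a list of per-example dicts into one dict of per-key lists.
    Assumes every example has the same keys as the first one."""
    keys = examples[0].keys()
    return {key: [example[key] for example in examples] for key in keys}


# Two made-up examples with a subset of the features listed above.
examples = [
    {"input_ids": [101, 2054, 102], "attention_mask": [1, 1, 1], "labels": [0.0, 1.0]},
    {"input_ids": [101, 2129, 102], "attention_mask": [1, 1, 1], "labels": [1.0, 0.0]},
]

batch = collate_examples(examples)
print(batch["input_ids"])
```

Each key in `batch` now holds one entry per example, which is the layout a training loop expects once those lists are turned into tensors.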
Train the model
You’re ready to start training your model now! |