```python
import torch


def collate_fn(examples):
    # Permute each clip from (num_channels, num_frames, height, width)
    # to (num_frames, num_channels, height, width).
    pixel_values = torch.stack(
        [example["video"].permute(1, 0, 2, 3) for example in examples]
    )
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
```
Then you just pass all of this, along with the datasets, to `Trainer`:
```python
from transformers import Trainer

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
```
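Once the trainer is constructed, fine-tuning is launched with a call to `train()` (shown here as a usage sketch):

```python
train_results = trainer.train()
```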
You might wonder why you passed along the `image_processor` as a tokenizer when you preprocessed the data already.