def collate_fn(examples):
    # permute to (num_frames, num_channels, height, width)
    pixel_values = torch.stack(
        [example["video"].permute(1, 0, 2, 3) for example in examples]
    )
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
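The permute is needed because preprocessing leaves each video tensor in (num_channels, num_frames, height, width) order, while the model expects the frames dimension first. A quick sanity check with dummy tensors illustrates the resulting batch shape; the channel count, frame count, and resolution below are placeholders, not values from the dataset used in this guide:

import torch

# two fake examples laid out as (num_channels, num_frames, height, width)
dummy_examples = [
    {"video": torch.randn(3, 8, 224, 224), "label": 0},
    {"video": torch.randn(3, 8, 224, 224), "label": 1},
]

batch = collate_fn(dummy_examples)
print(batch["pixel_values"].shape)  # torch.Size([2, 8, 3, 224, 224])
print(batch["labels"])              # tensor([0, 1])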
Then you just pass all of this along with the datasets to Trainer:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)
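Here, args and compute_metrics are defined earlier in the guide. As a rough reminder of what they might contain, the sketch below uses an accuracy metric and placeholder hyperparameters; the output directory, learning rate, and schedule are illustrative assumptions, not the guide's exact values:

import evaluate
import numpy as np
from transformers import TrainingArguments

# accuracy is a reasonable default metric for single-label video classification
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # Trainer passes an EvalPrediction with .predictions (logits) and .label_ids
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

args = TrainingArguments(
    output_dir="video-classification-checkpoint",  # placeholder path
    eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=4,
    remove_unused_columns=False,  # keep the "video" column so collate_fn can access it
)

Setting remove_unused_columns=False matters here: Trainer would otherwise drop dataset columns that don't match the model's forward signature, and collate_fn needs the raw "video" and "label" fields.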
You might wonder why you passed along the image_processor as a tokenizer when you preprocessed the data already.