```python
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel
from datasets import load_dataset

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
)
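# set the decoder start and pad tokens explicitly; they are not configured
# automatically when an encoder and a decoder are combined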
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
pixel_values = image_processor(image, return_tensors="pt").pixel_values
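# tokenize the target caption; the resulting token ids are used as labels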
labels = tokenizer(
    "an image of two cats chilling on a couch",
    return_tensors="pt",
).input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(pixel_values=pixel_values, labels=labels).loss
```
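Inference works through `generate`, which takes the same `pixel_values` and autoregressively decodes a caption. Here is a minimal sketch; note that the decoder is freshly paired with the encoder and still untrained, so the generated text will not be meaningful until the model is fine-tuned, and the `max_new_tokens` value is an arbitrary choice:

```python
# generate caption token ids from the pixel values and decode them to text
generated_ids = model.generate(pixel_values, max_new_tokens=20)
caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(caption)
```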
This model was contributed by nielsr.