import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering

processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200")

image = Image.open(example['image_id'])
question = example['question']

# prepare inputs
inputs = processor(image, question, return_tensors="pt")

model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200")

# forward pass
with torch.no_grad():
    outputs = model(**inputs)

# pick the highest-scoring answer class and map it back to its label
logits = outputs.logits
idx = logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])

Predicted answer: down

Zero-shot VQA

The previous model treated VQA as a classification task.
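To make the classification framing concrete: the model scores every entry in a fixed answer vocabulary, and the prediction is the label with the highest score. The sketch below reproduces that last step in plain Python. The `id2label` mapping and the logit values are made up for illustration and are not taken from the fine-tuned checkpoint; `softmax` and `top_k_answers` are hypothetical helper names.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_answers(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical answer vocabulary and logits (illustrative values only).
id2label = {0: "up", 1: "down", 2: "left", 3: "right"}
logits = [0.2, 2.5, -1.0, 0.7]

predictions = top_k_answers(logits, id2label, k=2)
print("Predicted answer:", predictions[0][0])
```

Because the answer must come from `id2label`, a classifier of this kind cannot produce an answer that was absent from its label set.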