You'll also need to specify the position of the `<mask>` token:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_mlm_model")
inputs = tokenizer(text, return_tensors="pt")
# Column indices (token positions) where the input contains the mask token
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
```

Pass your inputs to the model and return the logits of the masked token:

```python
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("username/my_awesome_eli5_mlm_model")
logits = model(**inputs).logits
# Keep only the vocabulary scores at the masked position(s) of the first sequence
mask_token_logits = logits[0, mask_token_index, :]
```

Then take the three candidate tokens with the highest probability for the masked position and print the text with each one filled in:

```python
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3_tokens:
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))
```

```
The Milky Way is a spiral galaxy.
```
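The last step ranks tokens by raw logits, while the text speaks of probabilities. The two rankings agree because softmax is monotonic, so it preserves order. The following is a minimal sketch of that idea in plain Python, not the library's implementation: the vocabulary, the logit values, and the helper names `softmax` and `top_k` are all invented for illustration. With real model output, `torch.softmax(mask_token_logits, dim=-1)` would give the corresponding probabilities.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k):
    """Return the indices of the k largest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


# Hypothetical toy vocabulary and logits (assumptions, not real model output).
toy_vocab = ["spiral", "large", "distant", "banana", "the"]
toy_logits = [4.0, 3.0, 2.5, -1.0, 0.5]

probs = softmax(toy_logits)
top_by_logit = top_k(toy_logits, 3)
top_by_prob = top_k(probs, 3)

# Ranking by logit and ranking by probability select the same tokens.
assert top_by_logit == top_by_prob

for i in top_by_prob:
    print(f"{toy_vocab[i]}: {probs[i]:.3f}")
```

Ranking on logits skips the softmax computation. Apply softmax only when you want to report the probability values themselves alongside the predicted tokens.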