You'll also need to specify the position of the `<mask>` token:
```py
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_mlm_model")
inputs = tokenizer(text, return_tensors="pt")
# torch.where returns one index tensor per dimension; [1] keeps the
# token positions (index 0 would be the batch indices).
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
```
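To see why the `[1]` index is needed: `torch.where` on a boolean tensor returns one tensor of indices per dimension, so for a batch of one sequence the second tensor holds the positions of the masked tokens. A minimal, self-contained illustration (the tensor values here are made up):

```py
import torch

# A fake boolean mask over a batch of one sequence of four tokens.
is_mask = torch.tensor([[False, True, False, True]])
rows, cols = torch.where(is_mask)
print(rows)  # tensor([0, 0]) - batch indices
print(cols)  # tensor([1, 3]) - token positions, what mask_token_index holds
```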
Pass your inputs to the model and return the logits of the masked token:
```py
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("username/my_awesome_eli5_mlm_model")
logits = model(**inputs).logits
# Keep only the logits at the masked position(s): shape [num_masks, vocab_size].
mask_token_logits = logits[0, mask_token_index, :]
```
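If you want probabilities rather than raw logits, you can apply a softmax over the vocabulary dimension first; because softmax is monotonic, the top-k ranking is the same either way. An optional step, continuing from the variables above:

```py
# Optional: normalized probabilities for the mask predictions.
mask_token_probs = torch.softmax(mask_token_logits, dim=-1)
```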
Then return the three candidate tokens with the highest probability and print them out:
```py
# Take the ids of the three highest-scoring tokens for the masked position.
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

for token in top_3_tokens:
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))
```
```text
The Milky Way is a spiral galaxy.
```
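For a quicker check of the same model, the `fill-mask` pipeline bundles tokenization, mask lookup, and decoding into one call. A minimal sketch, assuming the same finetuned checkpoint and a `<mask>`-style mask token; `top_k` controls how many candidates come back:

```py
from transformers import pipeline

# The pipeline finds the <mask> position and decodes candidates for you.
mask_filler = pipeline("fill-mask", model="username/my_awesome_eli5_mlm_model")
mask_filler("The Milky Way is a <mask> galaxy.", top_k=3)
```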