You'll also need to specify the position of the `<mask>` token:

```py
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_mlm_model")

text = "The Milky Way is a <mask> galaxy."
inputs = tokenizer(text, return_tensors="pt")
# torch.where returns (row, column) index tensors; keep the column positions of the mask token
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
```
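For this example sentence there is exactly one `<mask>`, so `mask_token_index` holds a single position. A quick sanity check (the printed index is hypothetical and depends on the tokenizer):

```py
# sanity check: the tokenizer must define a mask token, and one mask position should be found
assert tokenizer.mask_token_id is not None, "this tokenizer has no mask token"
print(mask_token_index)  # e.g. tensor([6]); the exact position depends on the tokenizer
```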
Pass your inputs to the model and retrieve the logits of the masked token:

```py
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("username/my_awesome_eli5_mlm_model")
logits = model(**inputs).logits
# logits for every vocabulary token at the mask position
mask_token_logits = logits[0, mask_token_index, :]
```
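The logits are unnormalized scores; if you want actual probabilities, you can apply a softmax over the vocabulary dimension. A minimal sketch using `torch.softmax`, not part of the original steps:

```py
# convert the mask-position logits into a probability distribution over the vocabulary
probs = torch.softmax(mask_token_logits, dim=-1)
print(probs.max().item())  # probability assigned to the single most likely token
```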
Then select the three candidate tokens with the highest probability and print them out:

```py
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()
for token in top_3_tokens:
    # replace the mask with each candidate token and show the completed sentence
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))
```
```
The Milky Way is a spiral galaxy.
```
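For quick experiments, the same checkpoint can also be used through the `fill-mask` pipeline, which handles tokenization and top-k ranking for you; a minimal sketch assuming the same model name as above:

```py
from transformers import pipeline

# the fill-mask pipeline wraps the tokenize -> forward -> top-k decode steps shown above
mask_filler = pipeline("fill-mask", model="username/my_awesome_eli5_mlm_model")
for prediction in mask_filler(text, top_k=3):
    print(prediction["sequence"])  # each completed sentence, ranked by score
```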