You'll also need to specify the position of the <mask> token:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("username/my_awesome_eli5_mlm_model")
inputs = tokenizer(text, return_tensors="pt")
# Column index of each <mask> token within the tokenized sequence
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
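
To see why the result of `torch.where` is indexed with `[1]`, here is a minimal sketch on a hand-written tensor. The token IDs and the mask ID `103` are made-up stand-ins, not values taken from the tokenizer above.

```python
import torch

# Hypothetical token IDs for one tokenized sentence; 103 stands in for the mask token ID.
mask_token_id = 103
input_ids = torch.tensor([[101, 1996, 103, 2003, 102]])

# On a 2-D boolean tensor, torch.where returns a (row_indices, column_indices) tuple.
rows, cols = torch.where(input_ids == mask_token_id)

# Element [1] of that tuple holds the positions within the sequence.
mask_token_index = cols
print(mask_token_index.tolist())
```

The row indices identify the example in the batch, and the column indices identify the token position, which is what the later indexing step needs.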

Pass your inputs to the model and return the logits of the masked token:

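from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("username/my_awesome_eli5_mlm_model")
# logits has shape (batch_size, sequence_length, vocab_size)
logits = model(**inputs).logits
# Keep only the vocabulary scores at the masked position(s) of the first example
mask_token_logits = logits[0, mask_token_index, :]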
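
The shape of the indexed result matters for the next step. The sketch below uses a zero tensor as a stand-in for real model output; the sequence length of 5 and vocabulary size of 8 are arbitrary assumptions chosen for illustration.

```python
import torch

# Stand-in for model output: batch of 1, sequence of 5 tokens, toy vocabulary of 8.
logits = torch.zeros(1, 5, 8)
mask_token_index = torch.tensor([2])  # position of the single mask token

# Indexing with a one-element index tensor keeps a leading dimension of size 1,
# so the result has one row per mask token and one column per vocabulary entry.
mask_token_logits = logits[0, mask_token_index, :]
print(mask_token_logits.shape)
```

Because the result keeps one row per mask token, a top-k search over the vocabulary runs along `dim=1`, and the first row is selected with `[0]`.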

Then take the three highest-scoring candidate tokens for the masked position and print the sentence with each one filled in:

top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()
for token in top_3_tokens:
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))

Example output (first line only):
The Milky Way is a spiral galaxy.
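
The ranking above is computed on raw logits rather than probabilities. Softmax preserves ordering, so the top three tokens are the same either way; if actual probabilities are wanted, something like the following sketch could be used. The logits here are made-up values over a toy five-token vocabulary, not output from the model.

```python
import torch

# Made-up logits for one masked position over a toy vocabulary of five tokens.
mask_token_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])

# Softmax turns logits into probabilities that sum to 1 along the vocabulary axis.
probs = torch.softmax(mask_token_logits, dim=1)

# Softmax preserves ordering, so top-k over probabilities matches top-k over logits.
top_probs, top_ids = torch.topk(probs, 3, dim=1)
top_logit_ids = torch.topk(mask_token_logits, 3, dim=1).indices

# Print each candidate token ID alongside its probability.
for token_id, p in zip(top_ids[0].tolist(), top_probs[0].tolist()):
    print(token_id, round(p, 3))
```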