zerohell commited on
Commit
848ce04
·
1 Parent(s): 9cde1e1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -3
README.md CHANGED
@@ -21,9 +21,10 @@ pipeline_tag: sentence-similarity
21
 
22
  ```python
23
  from transformers import BertModel, BertTokenizer
 
24
 
25
- model = BertModel.from_pretrained("tinydpr-acc_0.315-bs_307", cache_dir=".")
26
- tokenizer = BertTokenizer.from_pretrained("tinydpr-acc_0.315-bs_307", cache_dir=".")
27
 
28
  encoded_text = tokenizer(text="采用Dureader和cmrc2018数据集进行训练。", return_tensors="pt", max_length=512,
29
  padding='longest', truncation=True).to(model.device)
@@ -35,5 +36,5 @@ encoded_text = tokenizer(text="这个模型是采用什么数据集训练的?"
35
  encoded_text = {k: v.to(model.device) for (k, v) in encoded_text.items()}
36
  text_model_output2 = model(**encoded_text).pooler_output
37
  text_model_output2 = text_model_output2.cpu().detach().numpy()
38
- print(text_model_output1 @ text_model_output2)
39
  ```
 
21
 
22
  ```python
23
  from transformers import BertModel, BertTokenizer
24
+ import numpy as np
25
 
26
+ model = BertModel.from_pretrained("zerohell/tinydpr-acc_0.315-bs_307", cache_dir=".")
27
+ tokenizer = BertTokenizer.from_pretrained("zerohell/tinydpr-acc_0.315-bs_307", cache_dir=".")
28
 
29
  encoded_text = tokenizer(text="采用Dureader和cmrc2018数据集进行训练。", return_tensors="pt", max_length=512,
30
  padding='longest', truncation=True).to(model.device)
 
36
  encoded_text = {k: v.to(model.device) for (k, v) in encoded_text.items()}
37
  text_model_output2 = model(**encoded_text).pooler_output
38
  text_model_output2 = text_model_output2.cpu().detach().numpy()
39
+ print(text_model_output1 @ text_model_output2.T / np.linalg.norm(text_model_output1) / np.linalg.norm(text_model_output2))
40
  ```