Update README.md
Browse files
README.md
CHANGED
@@ -21,9 +21,10 @@ pipeline_tag: sentence-similarity
|
|
21 |
|
22 |
```python
|
23 |
from transformers import BertModel, BertTokenizer
|
|
|
24 |
|
25 |
-
model = BertModel.from_pretrained("tinydpr-acc_0.315-bs_307", cache_dir=".")
|
26 |
-
tokenizer = BertTokenizer.from_pretrained("tinydpr-acc_0.315-bs_307", cache_dir=".")
|
27 |
|
28 |
encoded_text = tokenizer(text="采用Dureader和cmrc2018数据集进行训练。", return_tensors="pt", max_length=512,
|
29 |
padding='longest', truncation=True).to(model.device)
|
@@ -35,5 +36,5 @@ encoded_text = tokenizer(text="这个模型是采用什么数据集训练的?"
|
|
35 |
encoded_text = {k: v.to(model.device) for (k, v) in encoded_text.items()}
|
36 |
text_model_output2 = model(**encoded_text).pooler_output
|
37 |
text_model_output2 = text_model_output2.cpu().detach().numpy()
|
38 |
-
print(text_model_output1 @ text_model_output2)
|
39 |
```
|
|
|
21 |
|
22 |
```python
|
23 |
from transformers import BertModel, BertTokenizer
|
24 |
+
import numpy as np
|
25 |
|
26 |
+
model = BertModel.from_pretrained("zerohell/tinydpr-acc_0.315-bs_307", cache_dir=".")
|
27 |
+
tokenizer = BertTokenizer.from_pretrained("zerohell/tinydpr-acc_0.315-bs_307", cache_dir=".")
|
28 |
|
29 |
encoded_text = tokenizer(text="采用Dureader和cmrc2018数据集进行训练。", return_tensors="pt", max_length=512,
|
30 |
padding='longest', truncation=True).to(model.device)
|
|
|
36 |
encoded_text = {k: v.to(model.device) for (k, v) in encoded_text.items()}
|
37 |
text_model_output2 = model(**encoded_text).pooler_output
|
38 |
text_model_output2 = text_model_output2.cpu().detach().numpy()
|
39 |
+
print(text_model_output1 @ text_model_output2.T / np.linalg.norm(text_model_output1) / np.linalg.norm(text_model_output2))
|
40 |
```
|