---
library_name: transformers
license: apache-2.0
datasets:
- jaeyong2/Ja-emb-PreView
language:
- ja
base_model:
- Alibaba-NLP/gte-multilingual-base
---

# Model Card for jaeyong2/gte-multilingual-base-Ja-embedding

This model is [Alibaba-NLP/gte-multilingual-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-base) fine-tuned for Japanese text embedding on the [jaeyong2/Ja-emb-PreView](https://huggingface.co/datasets/jaeyong2/Ja-emb-PreView) dataset, using a triplet loss that pairs each context (anchor) with its true title (positive) and a fake title (negative).
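For quick use, here is a minimal inference sketch (not part of the original card): it assumes the fine-tuned weights are the `jaeyong2/gte-multilingual-base-Ja-embedding` checkpoint listed in the results below, and takes the [CLS] token embedding, matching the training code in the next section.

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Assumed repo id, taken from the results table below.
model_name = "jaeyong2/gte-multilingual-base-Ja-embedding"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# gte-multilingual-base ships custom modeling code, hence trust_remote_code=True.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval()

def embed(texts):
    enc = tokenizer(texts, truncation=True, max_length=4096, padding=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**enc)
    return out[0][:, 0, :]  # [CLS] token embedding, as in the training loop

emb = embed(["今日は良い天気です。", "本日の天気は晴れです。"])
print(torch.nn.functional.cosine_similarity(emb[0], emb[1], dim=0).item())
```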
## Training

- H/W: Colab A100 40GB
- Data: jaeyong2/Ja-emb-PreView

```
model_name = "Alibaba-NLP/gte-multilingual-base"
dataset = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(torch.bfloat16)
triplet_loss = TripletLoss(margin=1.0)

optimizer = AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(3):
    model.train()
    total_loss = 0
    count = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = None
        for index in range(len(batch["context"])):
            anchor_encodings = tokenizer([batch["context"][index]], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
            positive_encodings = tokenizer([batch["Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
            negative_encodings = tokenizer([batch["Fake Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")

            anchor_encodings = batch_to_device(anchor_encodings, device)
            positive_encodings = batch_to_device(positive_encodings, device)
            negative_encodings = batch_to_device(negative_encodings, device)

            
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]
            
            if loss==None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)
        loss /= len(batch["context"])
        loss.backward()
        optimizer.step()
```
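For reference, `TripletLoss` above is assumed to implement the standard triplet margin loss (as in PyTorch's `triplet_margin_loss`), with Euclidean distance and margin $\alpha = 1.0$:

$$
\mathcal{L}(a, p, n) = \max\left(\lVert a - p \rVert_2 - \lVert a - n \rVert_2 + \alpha,\; 0\right)
$$

Minimizing it pushes each context embedding at least $\alpha$ closer to its true title than to its fake title.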

## Evaluation

Code:
```python
import datasets
import torch
from sklearn.metrics import pairwise_distances
from tqdm import tqdm

# Assumed helper (not defined in the original card): [CLS] token embedding,
# matching the training code above; `model` and `tokenizer` are reused from training.
def get_embedding(text, model, tokenizer):
    encodings = tokenizer([text], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
    encodings = encodings.to(next(model.parameters()).device)
    with torch.no_grad():
        return model(**encodings)[0][:, 0, :]

dataset = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
validation_dataset = dataset["test"].select(range(1000))

model.eval()

def evaluate(validation_dataset):
    correct_count = 0

    for item in tqdm(validation_dataset):
        query_embedding = get_embedding(item["context"], model, tokenizer)
        document_embedding = get_embedding(item["Title"], model, tokenizer)
        negative_embedding = get_embedding(item["Fake Title"], model, tokenizer)

        positive_distances = pairwise_distances(query_embedding.detach().cpu().float().numpy(), document_embedding.detach().cpu().float().numpy(), metric="cosine")
        negative_distances = pairwise_distances(query_embedding.detach().cpu().float().numpy(), negative_embedding.detach().cpu().float().numpy(), metric="cosine")

        # Correct when the context is closer (in cosine distance) to the true
        # Title than to the Fake Title.
        if positive_distances[0][0] < negative_distances[0][0]:
            correct_count += 1

    accuracy = correct_count / len(validation_dataset)
    return accuracy

results = evaluate(validation_dataset)
print(f"Validation Results: {results}")
```

Accuracy over 1,000 test examples (a prediction counts as correct when the context embedding is closer to the true Title than to the Fake Title):
- Alibaba-NLP/gte-multilingual-base : 0.979
- jaeyong2/gte-multilingual-base-Ja-embedding : 0.995


### License
- Alibaba-NLP/gte-multilingual-base : https://choosealicense.com/licenses/apache-2.0/