init commit
Browse files - modeling.py +8 -2
modeling.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
|
|
3 |
from typing import Optional, Tuple, List, Union
|
4 |
from transformers import Qwen2VLForConditionalGeneration
|
5 |
import logging
|
@@ -125,6 +126,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
125 |
max_doc_length: Optional[int] = None,
|
126 |
query_type: str = 'text',
|
127 |
doc_type: str = 'text',
|
|
|
128 |
show_progress: bool = False,
|
129 |
) -> List[float]:
|
130 |
|
@@ -211,8 +213,12 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
|
|
211 |
# move the batch to the correct device
|
212 |
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
213 |
|
214 |
-
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
215 |
-
|
|
|
|
|
|
|
|
|
216 |
|
217 |
if len(all_scores) == 1:
|
218 |
return all_scores[0]
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
+
import numpy as np
|
4 |
from typing import Optional, Tuple, List, Union
|
5 |
from transformers import Qwen2VLForConditionalGeneration
|
6 |
import logging
|
|
|
126 |
max_doc_length: Optional[int] = None,
|
127 |
query_type: str = 'text',
|
128 |
doc_type: str = 'text',
|
129 |
+
normalize_scores: bool = True,
|
130 |
show_progress: bool = False,
|
131 |
) -> List[float]:
|
132 |
|
|
|
213 |
# move the batch to the correct device
|
214 |
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
215 |
|
216 |
+
scores = self.forward(**batch).view(-1).cpu().float().numpy()
|
217 |
+
|
218 |
+
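# normalize scores to [0, 1] with sigmoid (NOTE(review): applied unconditionally here — the new `normalize_scores` flag is not checked in this hunk; confirm intent)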
|
219 |
+
scores = 1.0 / (1.0 + np.exp(-scores))
|
220 |
+
|
221 |
+
all_scores.extend(scores.tolist())
|
222 |
|
223 |
if len(all_scores) == 1:
|
224 |
return all_scores[0]
|