numb3r3 commited on
Commit
b9e3a5c
·
verified ·
1 Parent(s): 6112129

init commit

Browse files
Files changed (1) hide show
  1. modeling.py +8 -2
modeling.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  from torch import nn
 
3
  from typing import Optional, Tuple, List, Union
4
  from transformers import Qwen2VLForConditionalGeneration
5
  import logging
@@ -125,6 +126,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
125
  max_doc_length: Optional[int] = None,
126
  query_type: str = 'text',
127
  doc_type: str = 'text',
 
128
  show_progress: bool = False,
129
  ) -> List[float]:
130
 
@@ -211,8 +213,12 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
211
  # move the batch to the correct device
212
  batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
213
 
214
- scores = self.forward(**batch).view(-1).cpu().float().numpy().tolist()
215
- all_scores.extend(scores)
 
 
 
 
216
 
217
  if len(all_scores) == 1:
218
  return all_scores[0]
 
1
  import torch
2
  from torch import nn
3
+ import numpy as np
4
  from typing import Optional, Tuple, List, Union
5
  from transformers import Qwen2VLForConditionalGeneration
6
  import logging
 
126
  max_doc_length: Optional[int] = None,
127
  query_type: str = 'text',
128
  doc_type: str = 'text',
129
+ normalize_scores: bool = True,
130
  show_progress: bool = False,
131
  ) -> List[float]:
132
 
 
213
  # move the batch to the correct device
214
  batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
215
 
216
+ scores = self.forward(**batch).view(-1).cpu().float().numpy()
217
+
218
+ # normalize scores to [0, 1] with sigmoid
219
+ scores = 1.0 / (1.0 + np.exp(-scores))
220
+
221
+ all_scores.extend(scores.tolist())
222
 
223
  if len(all_scores) == 1:
224
  return all_scores[0]