numb3r3 committed on
Commit
7fc0965
·
verified ·
1 Parent(s): 55db745

fix logit bias

Browse files
Files changed (2) hide show
  1. README.md +3 -3
  2. modeling.py +3 -3
README.md CHANGED
@@ -178,7 +178,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
178
  image_pairs = [[query, doc] for doc in documents]
179
 
180
  scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")
181
- # [0.8576154708862305, 0.9356858730316162, 0.8496521711349487, 0.8664582967758179]
182
  ```
183
 
184
  **B. Textual Documents Reranking**
@@ -201,7 +201,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
201
  The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
202
  For instance, the returned scores in this case will be:
203
  ```bash
204
- [0.9127850532531738, 0.8384682536125183, 0.8870794177055359, 0.842738926410675]
205
  ```
206
 
207
  **C. Image Querying for Textual Documents**
@@ -221,7 +221,7 @@ Compared to `jina-reranker-v2-base-multilingual`, `jina-reranker-m0` significant
221
  image_pairs = [[doc, query] for doc in documents]
222
  scores = model.compute_score(image_pairs, max_length=2048, doc_type="text")
223
 
224
- # [0.9048659801483154, 0.8266222476959229, 0.8326289653778076, 0.9075747132301331]
225
  ```
226
 
227
  # Model Performance
 
178
  image_pairs = [[query, doc] for doc in documents]
179
 
180
  scores = model.compute_score(image_pairs, max_length=2048, doc_type="image")
181
+ # [0.766852855682373, 0.9265167713165283, 0.7554926872253418, 0.7858350276947021]
182
  ```
183
 
184
  **B. Textual Documents Reranking**
 
201
  The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
202
  For instance, the returned scores in this case will be:
203
  ```bash
204
+ [0.8778123259544373, 0.7254930734634399, 0.8271589875221252, 0.7437640428543091]
205
  ```
206
 
207
  **C. Image Querying for Textual Documents**
 
221
  image_pairs = [[doc, query] for doc in documents]
222
  scores = model.compute_score(image_pairs, max_length=2048, doc_type="text")
223
 
224
+ # [0.8673955798149109, 0.6999112367630005, 0.7031826972961426, 0.8744207620620728]
225
  ```
226
 
227
  # Model Performance
modeling.py CHANGED
@@ -10,7 +10,7 @@ from transformers.image_utils import load_image
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- LOGIT_SCALE = 0.68
14
 
15
  def load_images(images, lazy_load: bool = True):
16
  # Disable PIL's DecompressionBomb threshold for reading large images.
@@ -123,7 +123,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
123
  pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
124
  batch_size: int = 8,
125
  max_length: int = 10240,
126
- max_query_length: int = 1024,
127
  max_doc_length: Optional[int] = None,
128
  query_type: str = 'text',
129
  doc_type: str = 'text',
@@ -219,7 +219,7 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
219
  scores = self.forward(**batch).view(-1).cpu().float().numpy()
220
 
221
  # normalize scores to [0, 1] with sigmoid with a scale
222
- scores = 1.0 / (1.0 + np.exp(-scores * LOGIT_SCALE))
223
 
224
  all_scores.extend(scores.tolist())
225
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+ LOGIT_BIAS = 1.45 # logit bias for sigmoid normalization
14
 
15
  def load_images(images, lazy_load: bool = True):
16
  # Disable PIL's DecompressionBomb threshold for reading large images.
 
123
  pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
124
  batch_size: int = 8,
125
  max_length: int = 10240,
126
+ max_query_length: int = 512,
127
  max_doc_length: Optional[int] = None,
128
  query_type: str = 'text',
129
  doc_type: str = 'text',
 
219
  scores = self.forward(**batch).view(-1).cpu().float().numpy()
220
 
221
  # normalize scores to [0, 1] with a sigmoid shifted by LOGIT_BIAS
222
+ scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))
223
 
224
  all_scores.extend(scores.tolist())
225