numb3r3 commited on
Commit
d493c12
·
verified ·
1 Parent(s): 7fc0965

draft support image query

Browse files
Files changed (1) hide show
  1. modeling.py +17 -3
modeling.py CHANGED
@@ -183,10 +183,24 @@ class JinaVLForRanking(Qwen2VLForConditionalGeneration):
183
  batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))
184
 
185
  batch_images = None
 
 
 
 
 
 
 
186
  if doc_type == 'image':
187
- batch_images = load_images([d for (q, d) in mini_batch])
188
- elif query_type == 'image':
189
- batch_images = load_images([q for (q, d) in mini_batch])
 
 
 
 
 
 
 
190
 
191
  batch = self._processor(
192
  text=batch_inputs,
 
183
  batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))
184
 
185
  batch_images = None
186
+ # if doc_type == 'image':
187
+ # batch_images = load_images([d for (q, d) in mini_batch])
188
+ # elif query_type == 'image':
189
+ # batch_images = load_images([q for (q, d) in mini_batch])
190
+
191
+ doc_images = []
192
+ query_images = []
193
  if doc_type == 'image':
194
+ doc_images = load_images([d for (q, d) in mini_batch])
195
+ if query_type == 'image':
196
+ query_images = load_images([q for (q, d) in mini_batch])
197
+
198
+ if len(doc_images) == len(query_images) and len(doc_images) > 0:
199
+ batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
200
+ elif len(doc_images) > 0:
201
+ batch_images = doc_images
202
+ elif len(query_images) > 0:
203
+ batch_images = query_images
204
 
205
  batch = self._processor(
206
  text=batch_inputs,