izhx committed · Commit ea01a1f · verified · 1 Parent(s): 06bd79d

Integrate `trust_remote_code` and `sentence_transformers` (#9)


- Integrate `trust_remote_code` and `sentence_transformers` (91b4ea93df6cefd45dde4ff65cdbe5e5dfc73b4f)
- Create 1_Pooling/config.json (f8efda76aad2fba36a7bb60d4d1d5afb461fe034)
- Update config.json (aa514959fedd5cef264246d24e2400aca0d67b82)
- Create config_sentence_transformers.json (eb0f9bec0d98e58568e098cacd619b304ef0a329)
- Create custom_st.py (64a2954f75ac37af17f7431e6d93753a4fa46ad3)
- Create modules.json (51f2ba553d618f18cb005d2e2e244ab821e14ba6)
- Update README.md (54e1aba8eec2520ba49966da794a703dde20a7f0)

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "word_embedding_dimension": 3584,
+     "pooling_mode_cls_token": false,
+     "pooling_mode_mean_tokens": false,
+     "pooling_mode_max_tokens": false,
+     "pooling_mode_mean_sqrt_len_tokens": false,
+     "pooling_mode_weightedmean_tokens": false,
+     "pooling_mode_lasttoken": true,
+     "include_prompt": true
+ }
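
This configures last-token pooling over the 3584-dimensional hidden states, with prompts included in the pooled sequence. A minimal sketch of the behaviour with dummy tensors (an illustration only, assuming a `sentence-transformers` version that supports last-token pooling):

```python
import torch
from sentence_transformers.models import Pooling

# Mirrors 1_Pooling/config.json: only last-token pooling is enabled.
pooling = Pooling(word_embedding_dimension=3584, pooling_mode="lasttoken")

features = {
    "token_embeddings": torch.randn(2, 5, 3584),           # (batch, seq_len, dim)
    "attention_mask": torch.ones(2, 5, dtype=torch.long),  # right padding, as the model sets
}
print(pooling(features)["sentence_embedding"].shape)       # torch.Size([2, 3584])
```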
README.md CHANGED
@@ -3691,57 +3691,94 @@ The `GME` models support three types of input: **text**, **image**, and **image-
  |[`gme-Qwen2-VL-2B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct) | 2.21B | 32768 | 1536 | 65.27 | 68.41 | 64.45 |
  |[`gme-Qwen2-VL-7B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct) | 8.29B | 32768 | 3584 | 67.48 | 71.36 | 67.44 |

  ## Usage
- **Use with custom code**
+ **Transformers**

- ```python
- # You can find the script gme_inference.py in https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
- from gme_inference import GmeQwen2VL
-
- model = GmeQwen2VL('Alibaba-NLP/gme-Qwen2-VL-7B-Instruct')
-
+ ```python
+ from transformers import AutoModel
+
+ t2i_prompt = 'Find an image that matches the given text.'
  texts = [
-     "What kind of car is this?",
-     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
+     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
+     "Alibaba office.",
  ]
  images = [
-     'https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg',
-     'https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
  ]

+ gme = AutoModel.from_pretrained(
+     "Alibaba-NLP/gme-Qwen2-VL-7B-Instruct",
+     torch_dtype="float16", device_map='cuda', trust_remote_code=True
+ )
+
  # Single-modal embedding
  e_text = gme.get_text_embeddings(texts=texts)
  e_image = gme.get_image_embeddings(images=images)
- print((e_text * e_image).sum(-1))
- ## tensor([0.1702, 0.5278], dtype=torch.float16)
+ print('Single-modal', (e_text @ e_image.T).tolist())
+ ## Single-modal [[0.279296875, 0.0002658367156982422], [0.06427001953125, 0.304443359375]]

  # How to set embedding instruction
- e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
+ e_query = gme.get_text_embeddings(texts=texts, instruction=t2i_prompt)
  # If is_query=False, we always use the default instruction.
  e_corpus = gme.get_image_embeddings(images=images, is_query=False)
- print((e_query * e_corpus).sum(-1))
- ## tensor([0.2000, 0.5752], dtype=torch.float16)
+ print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
+ ## Single-modal with instruction [[0.32861328125, 0.026336669921875], [0.09466552734375, 0.3134765625]]

  # Fused-modal embedding
  e_fused = gme.get_fused_embeddings(texts=texts, images=images)
- print((e_fused[0] * e_fused[1]).sum())
- ## tensor(0.6826, dtype=torch.float16)
+ print('Fused-modal', (e_fused @ e_fused.T).tolist())
+ ## Fused-modal [[1.0, 0.0308685302734375], [0.0308685302734375, 1.0]]
  ```

- <!-- <details>
- <summary>With transformers</summary>
+ **sentence_transformers**
+
+ The `encode` function accepts `str` inputs or `dict` inputs with key(s) in `{'text', 'image', 'prompt'}`.
+
+ **Do not pass `prompt` as an argument to `encode`**; pass the input as a `dict` with a `prompt` key.

  ```python
- # Requires transformers>=4.46.2
+ from sentence_transformers import SentenceTransformer

- TODO
+ t2i_prompt = 'Find an image that matches the given text.'
+ texts = [
+     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
+     "Alibaba office.",
+ ]
+ images = [
+     'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
+     'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
+ ]
+
+ gme_st = SentenceTransformer("Alibaba-NLP/gme-Qwen2-VL-7B-Instruct")

- # [[0.3016996383666992, 0.7503870129585266, 0.3203084468841553]]
+ # Single-modal embedding
+ e_text = gme_st.encode(texts, convert_to_tensor=True)
+ e_image = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
+ print('Single-modal', (e_text @ e_image.T).tolist())
+ ## Single-modal [[0.27880859375, 0.0005745887756347656], [0.06500244140625, 0.306640625]]
+
+ # How to set embedding instruction
+ e_query = gme_st.encode([dict(text=t, prompt=t2i_prompt) for t in texts], convert_to_tensor=True)
+ # If no prompt, we always use the default instruction.
+ e_corpus = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
+ print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
+ ## Single-modal with instruction [[0.328369140625, 0.0269927978515625], [0.09521484375, 0.316162109375]]
+
+ # Fused-modal embedding
+ e_fused = gme_st.encode([dict(text=t, image=i) for t, i in zip(texts, images)], convert_to_tensor=True)
+ print('Fused-modal', (e_fused @ e_fused.T).tolist())
+ ## Fused-modal [[0.99951171875, 0.0311737060546875], [0.0311737060546875, 1.0009765625]]
  ```

- </details>
- -->

  ## Evaluation

config.json CHANGED
@@ -1,8 +1,10 @@
  {
- "_name_or_path": "gme-Qwen2-VL-7B-Instruct",
- "architectures": [
-     "Qwen2VLForConditionalGeneration"
- ],
+ "_name_or_path": "Alibaba-NLP/gme-Qwen2-VL-7B-Instruct",
+ "architectures": ["GmeQwen2VLForVision2Seq"],
+ "auto_map": {
+     "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VLForVision2Seq",
+     "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig"
+ },
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
+ {
+     "prompts": {
+         "query": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+     },
+     "default_prompt_name": null,
+     "similarity_fn_name": null
+ }
custom_st.py ADDED
@@ -0,0 +1,221 @@
+ from io import BytesIO
+ from typing import Any, Dict, Optional, List
+ import torch
+ from PIL import Image
+ from sentence_transformers.models import Transformer as BaseTransformer
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+
+ class MultiModalTransformer(BaseTransformer):
+     def __init__(
+         self,
+         model_name_or_path: str,
+         cache_dir: Optional[str] = None,
+         tokenizer_args: Optional[Dict[str, Any]] = None,
+         min_image_tokens: int = 256,
+         max_image_tokens: int = 1280,
+         max_length: int = 1800,
+         **kwargs,
+     ):
+         super().__init__(model_name_or_path, **kwargs)
+         if tokenizer_args is None:
+             tokenizer_args = {}
+         tokenizer_args.pop("trust_remote_code", None)
+
+         # Initialize processor
+         min_pixels = min_image_tokens * 28 * 28
+         max_pixels = max_image_tokens * 28 * 28
+         self.processor = AutoProcessor.from_pretrained(
+             model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
+         )
+         self.processor.tokenizer.padding_side = 'right'
+         self.sep = ' '
+         self.max_length = max_length
+         self.normalize = True
+
+     def _load_model(
+         self,
+         model_name_or_path: str,
+         config,
+         cache_dir: str,
+         backend: str,
+         is_peft_model: bool,
+         **model_args,
+     ) -> None:
+         model_args.pop("trust_remote_code", None)
+         self.auto_model = AutoModelForVision2Seq.from_pretrained(
+             model_name_or_path, torch_dtype=torch.float16, **model_args
+         )
+
+     def forward(
+         self, features: Dict[str, torch.Tensor], **kwargs
+     ) -> Dict[str, torch.Tensor]:
+         if features.get("inputs_embeds", None) is None:
+             features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
+         if features.get("pixel_values", None) is not None:
+             features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
+             image_embeds = self.auto_model.visual(
+                 features["pixel_values"], grid_thw=features["image_grid_thw"]
+             )
+             image_mask = features["input_ids"] == self.auto_model.config.image_token_id
+             features["inputs_embeds"][image_mask] = image_embeds
+         # features.pop("pixel_values")
+         # features.pop("image_grid_thw")
+         # features.pop("input_ids")
+         inputs = {k: v for k, v in features.items() if k in 'position_ids,attention_mask,inputs_embeds'}
+         outputs = self.auto_model.model(
+             **inputs,
+             return_dict=True,
+             output_hidden_states=True,
+             # **kwargs
+         )
+         # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
+         # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+         # if left_padding:
+         #     embeddings = outputs.last_hidden_state
+         # else:
+         #     sequence_lengths = pooling_mask.sum(dim=1) - 1
+         #     embeddings = outputs.last_hidden_state[torch.arange(
+         #         outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
+         #     ), sequence_lengths]
+         features.update({"token_embeddings": outputs.last_hidden_state})
+         return features
+
+     def tokenize(self, texts: List[List[Dict[str, Any]]] | List[str]) -> Dict[str, torch.Tensor]:
+         default_instruction = 'You are a helpful assistant.'
+
+         all_texts, all_images = list(), list()
+         for item in texts:
+             if isinstance(item, str):
+                 txt, img, inst = item, None, default_instruction
+             elif isinstance(item, dict):
+                 txt = item.get('text', None)
+                 img = item.get('image', None)
+                 inst = item.get('prompt', default_instruction)
+             else:
+                 raise RuntimeError(f'Input format not supported! {item=}')
+
+             input_str = ''
+             if img is None:
+                 all_images = None  # All examples in the same batch are consistent
+                 # or will have ValueError: Could not make a flat list of images from xxxx
+             else:
+                 input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                 img = fetch_image(img)
+                 all_images.append(img)
+             if txt is not None:
+                 input_str += txt
+             msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+             all_texts.append(msg)
+
+         inputs = self.processor(
+             text=all_texts,
+             images=all_images,
+             padding="longest",
+             truncation=True,
+             max_length=self.max_seq_length,
+             return_tensors='pt'
+         )
+         return inputs
+
+
+ ### Copied from qwen_vl_utils.vision_process.py
+ import base64
+ import logging  # used by smart_resize below
+ import math  # used by the *_by_factor helpers and smart_resize
+ from io import BytesIO
+ import requests
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+
+     if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
+         logging.warning(
+             f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
+         )
+         if h_bar > w_bar:
+             h_bar = w_bar * MAX_RATIO
+         else:
+             w_bar = h_bar * MAX_RATIO
+     return h_bar, w_bar
+
+
+ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         image_obj = Image.open(requests.get(image, stream=True).raw)
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+     image = image_obj.convert("RGB")
+     ## resize
+     # if "resized_height" in ele and "resized_width" in ele:
+     #     resized_height, resized_width = smart_resize(
+     #         ele["resized_height"],
+     #         ele["resized_width"],
+     #         factor=size_factor,
+     #     )
+     # else:
+     width, height = image.size
+     # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+     # max_pixels = ele.get("max_pixels", MAX_PIXELS)
+     resized_height, resized_width = smart_resize(
+         height,
+         width,
+         factor=size_factor,
+         min_pixels=MIN_PIXELS,
+         max_pixels=MAX_PIXELS,
+     )
+     image = image.resize((resized_width, resized_height))
+
+     return image
+ ###
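
A quick sketch of the image loader defined above (assumes `custom_st.py` is on the import path; the URL is one of the README examples):

```python
from custom_st import fetch_image

# Accepts an http(s) URL, local path, file:// path, data:image base64 URI, or PIL.Image,
# and returns an RGB image resized so both sides are multiples of IMAGE_FACTOR (28).
img = fetch_image('https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg')
print(img.size, img.size[0] % 28, img.size[1] % 28)  # (width, height) 0 0
```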
modeling_gme_qwen2vl.py ADDED
@@ -0,0 +1,313 @@
+ from __future__ import annotations
+
+ import base64
+ import logging
+ import math
+ import os
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Union
+
+ import requests
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+ from transformers import (
+     AutoProcessor,
+     PreTrainedModel,
+     Qwen2VLConfig,
+     Qwen2VLForConditionalGeneration,
+ )
+ import os
+
+
+ class GmeQwen2VLConfig(Qwen2VLConfig):
+     def __init__(
+         self,
+         min_image_tokens: int = 256,
+         max_image_tokens: int = 1280,
+         max_length: int = 1800,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.min_image_tokens = min_image_tokens
+         self.max_image_tokens = max_image_tokens
+         self.max_length = max_length
+
+
+ class GmeQwen2VLForVision2Seq(PreTrainedModel):
+     config_class = GmeQwen2VLConfig
+     base_model_prefix: str = "base"
+
+     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
+         super().__init__(config)
+         self.base = Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)
+         self.base.tie_weights()  # It's important to produce same outputs.
+
+         min_pixels: int = config.min_image_tokens * 28 * 28
+         max_pixels: int = config.max_image_tokens * 28 * 28
+         self.processor = AutoProcessor.from_pretrained(
+             config._name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
+         )
+         self.max_length: int = config.max_length
+         self.normalize: bool = True
+         self.processor.tokenizer.padding_side = "right"
+         self.default_instruction: str = "You are a helpful assistant."
+         self.sep: str = " "
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         # pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         # video_grid_thw: Optional[torch.LongTensor] = None,
+         pooling_mask: Optional[torch.LongTensor] = None,
+         **kwargs
+     ) -> torch.Tensor:
+         if inputs_embeds is None:
+             inputs_embeds = self.base.model.embed_tokens(input_ids)
+             if pixel_values is not None:
+                 pixel_values = pixel_values.type(self.base.visual.get_dtype())
+                 image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+                 image_mask = input_ids == self.base.config.image_token_id
+                 inputs_embeds[image_mask] = image_embeds
+             # if pixel_values_videos is not None:
+             #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
+             #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+             #     video_mask = input_ids == self.base.config.video_token_id
+             #     inputs_embeds[video_mask] = video_embeds
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(inputs_embeds.device)
+
+         outputs = self.base.model(
+             input_ids=None,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+         )
+
+         pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+         left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+         if left_padding:
+             embeddings = outputs.last_hidden_state[:, -1]
+         else:
+             sequence_lengths = pooling_mask.sum(dim=1) - 1
+             batch_size = outputs.last_hidden_state.shape[0]
+             embeddings = outputs.last_hidden_state[torch.arange(
+                 batch_size, device=outputs.last_hidden_state.device
+             ), sequence_lengths]
+         if self.normalize:
+             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+         return embeddings.contiguous()
+
+     def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
+         self.eval()
+         # Inputs must be batched
+         input_texts, input_images = list(), list()
+         for t, i in zip(texts, images):
+             if not is_query or instruction is None:
+                 instruction = self.default_instruction
+             input_str = ''
+             if i is None:
+                 input_images = None  # All examples in the same batch are consistent
+             else:
+                 input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                 i = fetch_image(i)
+                 input_images.append(i)
+             if t is not None:
+                 input_str += t
+             msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+             input_texts.append(msg)
+
+         inputs = self.processor(
+             text=input_texts,
+             images=input_images,
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors='pt'
+         )
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
+         with torch.inference_mode():
+             embeddings = self.forward(**inputs)
+         return embeddings
+
+     def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
+         return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
+
+     def encode_queries(self, queries: List[str], **kwargs):
+         embeddings = self.encode(queries, **kwargs)
+         return embeddings
+
+     def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
+         if type(corpus) is dict:
+             sentences = [
+                 (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
+                 if "title" in corpus
+                 else corpus["text"][i].strip()
+                 for i in range(len(corpus["text"]))
+             ]
+         else:
+             sentences = [
+                 (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
+                 for doc in corpus
+             ]
+         embeddings = self.encode(sentences, is_query=False, **kwargs)
+         return embeddings
+
+     def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
+         return self.get_fused_embeddings(images=images, **kwargs)
+
+     def get_text_embeddings(self, texts: list[str], **kwargs):
+         return self.get_fused_embeddings(texts=texts, **kwargs)
+
+     def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
+         if isinstance(images, DataLoader):
+             image_loader = images
+             batch_size = image_loader.batch_size
+             image_loader.dataset.transform = None
+         else:
+             batch_size = kwargs.pop('batch_size', 32)
+             if images is None:
+                 image_loader = None
+             else:
+                 image_loader = DataLoader(
+                     images,
+                     batch_size=batch_size,
+                     shuffle=False,
+                     collate_fn=custom_collate_fn,
+                     num_workers=min(math.floor(os.cpu_count() / 2), 8),
+                 )
+
+         if texts is None:
+             assert image_loader is not None
+             n_batch = len(image_loader)
+         else:
+             n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
+             image_loader = image_loader or [None] * n_batch
+
+         all_embeddings = list()
+         none_batch = [None] * batch_size
+         show_progress_bar = kwargs.pop('show_progress_bar', False)
+         pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
+         for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
+             text_batch = none_batch if texts is None else texts[n: n+batch_size]
+             img_batch = none_batch if img_batch is None else img_batch
+             embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
+             pbar.update(1)
+             all_embeddings.append(embeddings.cpu())
+         pbar.close()
+         all_embeddings = torch.cat(all_embeddings, dim=0)
+         return all_embeddings
+
+
+ def custom_collate_fn(batch):
+     return batch
+
+
+ ### Copied from qwen_vl_utils.vision_process.py
+ import base64
+ from io import BytesIO
+ import requests
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+
+     if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
+         logging.warning(
+             f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
+         )
+         if h_bar > w_bar:
+             h_bar = w_bar * MAX_RATIO
+         else:
+             w_bar = h_bar * MAX_RATIO
+     return h_bar, w_bar
+
+
+ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         image_obj = Image.open(requests.get(image, stream=True).raw)
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+     image = image_obj.convert("RGB")
+     ## resize
+     # if "resized_height" in ele and "resized_width" in ele:
+     #     resized_height, resized_width = smart_resize(
+     #         ele["resized_height"],
+     #         ele["resized_width"],
+     #         factor=size_factor,
+     #     )
+     # else:
+     width, height = image.size
+     # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+     # max_pixels = ele.get("max_pixels", MAX_PIXELS)
+     resized_height, resized_width = smart_resize(
+         height,
+         width,
+         factor=size_factor,
+         min_pixels=MIN_PIXELS,
+         max_pixels=MAX_PIXELS,
+     )
+     image = image.resize((resized_width, resized_height))
+
+     return image
+ ###
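
A small sketch of the batching options `get_fused_embeddings` exposes (assumes `gme` is the `AutoModel` instance loaded as in the README; `batch_size` and `show_progress_bar` are popped from `kwargs` above, and `instruction` is forwarded to `embed`):

```python
# Embeddings are computed in batches of `batch_size`, moved to CPU, and concatenated.
emb = gme.get_text_embeddings(
    texts=["What kind of car is this?", "Alibaba office."],
    instruction="Find an image that matches the given text.",
    batch_size=8,
    show_progress_bar=True,
)
print(emb.shape, emb.dtype)  # torch.Size([2, 3584]) torch.float16, L2-normalized rows
```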
modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+     {
+         "idx": 0,
+         "name": "0",
+         "path": "",
+         "type": "custom_st.MultiModalTransformer"
+     },
+     {
+         "idx": 1,
+         "name": "1",
+         "path": "1_Pooling",
+         "type": "sentence_transformers.models.Pooling"
+     },
+     {
+         "idx": 2,
+         "name": "2",
+         "path": "2_Normalize",
+         "type": "sentence_transformers.models.Normalize"
+     }
+ ]
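
These three entries define the sentence-transformers pipeline: the custom multimodal transformer from `custom_st.py`, last-token pooling configured by `1_Pooling/config.json`, and L2 normalization. A minimal check of how they load (a sketch, using the repo id from the README):

```python
from sentence_transformers import SentenceTransformer

gme_st = SentenceTransformer("Alibaba-NLP/gme-Qwen2-VL-7B-Instruct")

# The loaded model is a module sequence driven by modules.json.
for idx, module in enumerate(gme_st):
    print(idx, type(module).__name__)
# 0 MultiModalTransformer
# 1 Pooling
# 2 Normalize
```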