izhx and Samoed committed
Commit 73dd060 · verified · 1 parent: c937797

Integrate to AutoModel (#10)


- Try to integrate AutoModel (24370b049274f4d06c01d479db9cae41ade39e56)
- Update gme_inference.py (02d3017aa1d162ad0ee332345548a1b28351e07a)
- Update README.md (2944f212de0c1ccbe1bc993d660f4a3919b532ee)
- Update README.md (5fa49f4b1ef97eacaf12285d8aca835728b89f68)
- lint (e0e7250f5772018b48d567e195e682d7f4a52586)
- Create modeling_gme_qwen2vl.py (fd9d1a41f746fa74148ccffaf660f7378a0a1026)
- Update config.json (e93daf4018038593fb92d37d3fd1c2fe61f88d12)
- revert gme_inference.py (facd32fbf0276e7365b453eb1400abdf5c0fee5e)
- Update README.md (9e37c1412199889941439551345ead1317bc7e46)
- revert README.md (4e639a8908b3d2d559abe9feaebe99cd19c5806e)


Co-authored-by: Solomatin Roman <Samoed@users.noreply.huggingface.co>

Files changed (3)
  1. README.md +1 -1
  2. config.json +9 -11
  3. modeling_gme_qwen2vl.py +314 -0
README.md CHANGED
@@ -3879,4 +3879,4 @@ If you find our paper or models helpful, please consider cite:
     primaryClass={cs.CL},
     url={http://arxiv.org/abs/2412.16855},
   }
-```
+```
config.json CHANGED
@@ -1,8 +1,10 @@
 {
-  "_name_or_path": "gme-Qwen2-VL-2B-Instruct",
-  "architectures": [
-    "Qwen2VLForConditionalGeneration"
-  ],
+  "_name_or_path": "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
+  "architectures": ["GmeQwen2VLForVision2Seq"],
+  "auto_map": {
+    "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VLForVision2Seq",
+    "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
@@ -13,17 +15,13 @@
   "intermediate_size": 8960,
   "max_position_embeddings": 32768,
   "max_window_layers": 28,
-  "model_type": "qwen2_vl",
+  "model_type": "gme_qwen2_vl",
   "num_attention_heads": 12,
   "num_hidden_layers": 28,
   "num_key_value_heads": 2,
-  "rms_norm_eps": 1e-06,
+  "rms_norm_eps": 1e-6,
   "rope_scaling": {
-    "mrope_section": [
-      16,
-      24,
-      24
-    ],
+    "mrope_section": [16, 24, 24],
     "type": "mrope"
   },
   "rope_theta": 1000000.0,
modeling_gme_qwen2vl.py ADDED
@@ -0,0 +1,314 @@
+from __future__ import annotations
+
+import base64
+import logging
+import math
+import os
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+from transformers import (
+    AutoProcessor,
+    PreTrainedModel,
+    Qwen2VLConfig,
+    Qwen2VLForConditionalGeneration,
+)
+import os
+
+
+class GmeQwen2VLConfig(Qwen2VLConfig):
+    def __init__(
+        self,
+        min_image_tokens: int = 256,
+        max_image_tokens: int = 1280,
+        max_length: int = 1800,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.min_image_tokens = min_image_tokens
+        self.max_image_tokens = max_image_tokens
+        self.max_length = max_length
+
+
+class GmeQwen2VLForVision2Seq(PreTrainedModel):
+    config_class = GmeQwen2VLConfig
+    base_model_prefix: str = "base"
+
+    def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
+        super().__init__(config)
+        self.base = Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)
+        self.base.tie_weights()  # It's important to produce same outputs.
+
+        min_pixels: int = config.min_image_tokens * 28 * 28
+        max_pixels: int = config.max_image_tokens * 28 * 28
+        self.processor = AutoProcessor.from_pretrained(
+            config._name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
+        )
+        self.max_length: int = config.max_length
+        self.normalize: bool = True
+        self.processor.tokenizer.padding_side = "right"
+        self.default_instruction: str = "You are a helpful assistant."
+        self.sep: str = " "
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        # pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        # video_grid_thw: Optional[torch.LongTensor] = None,
+        pooling_mask: Optional[torch.LongTensor] = None,
+        **kwargs
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.base.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.base.visual.get_dtype())
+                image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+                image_mask = input_ids == self.base.config.image_token_id
+                inputs_embeds[image_mask] = image_embeds
+            # if pixel_values_videos is not None:
+            #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
+            #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+            #     video_mask = input_ids == self.base.config.video_token_id
+            #     inputs_embeds[video_mask] = video_embeds
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.base.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+        if left_padding:
+            embeddings = outputs.last_hidden_state[:, -1]
+        else:
+            sequence_lengths = pooling_mask.sum(dim=1) - 1
+            batch_size = outputs.last_hidden_state.shape[0]
+            embeddings = outputs.last_hidden_state[torch.arange(
+                batch_size, device=outputs.last_hidden_state.device
+            ), sequence_lengths]
+        if self.normalize:
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.contiguous()
+
+    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
+        self.eval()
+        # Inputs must be batched
+        input_texts, input_images = list(), list()
+        for t, i in zip(texts, images):
+            if not is_query or instruction is None:
+                instruction = self.default_instruction
+            input_str = ''
+            if i is None:
+                input_images = None  # All examples in the same batch are consistent
+            else:
+                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                i = fetch_image(i)
+                input_images.append(i)
+            if t is not None:
+                input_str += t
+            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+            input_texts.append(msg)
+
+        inputs = self.processor(
+            text=input_texts,
+            images=input_images,
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors='pt'
+        )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
+        with torch.inference_mode():
+            embeddings = self.forward(**inputs)
+        return embeddings
+
+    def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
+        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
+
+    def encode_queries(self, queries: List[str], **kwargs):
+        embeddings = self.encode(queries, **kwargs)
+        return embeddings
+
+    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
+        if type(corpus) is dict:
+            sentences = [
+                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
+                if "title" in corpus
+                else corpus["text"][i].strip()
+                for i in range(len(corpus["text"]))
+            ]
+        else:
+            sentences = [
+                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
+                for doc in corpus
+            ]
+        embeddings = self.encode(sentences, is_query=False, **kwargs)
+        return embeddings
+
+    def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
+        return self.get_fused_embeddings(images=images, **kwargs)
+
+    def get_text_embeddings(self, texts: list[str], **kwargs):
+        return self.get_fused_embeddings(texts=texts, **kwargs)
+
+    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
+        if isinstance(images, DataLoader):
+            image_loader = images
+            batch_size = image_loader.batch_size
+            image_loader.dataset.transform = None
+        else:
+            batch_size = kwargs.pop('batch_size', 32)
+            if images is None:
+                image_loader = None
+            else:
+                image_loader = DataLoader(
+                    images,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    collate_fn=custom_collate_fn,
+                    num_workers=min(math.floor(os.cpu_count() / 2), 8),
+                )
+
+        if texts is None:
+            assert image_loader is not None
+            n_batch = len(image_loader)
+        else:
+            n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
+            image_loader = image_loader or [None] * n_batch
+
+        all_embeddings = list()
+        none_batch = [None] * batch_size
+        show_progress_bar = kwargs.pop('show_progress_bar', False)
+        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
+        for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
+            text_batch = none_batch if texts is None else texts[n: n+batch_size]
+            img_batch = none_batch if img_batch is None else img_batch
+            embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
+            pbar.update(1)
+            all_embeddings.append(embeddings.cpu())
+        pbar.close()
+        all_embeddings = torch.cat(all_embeddings, dim=0)
+        return all_embeddings
+
+
+def custom_collate_fn(batch):
+    return batch
+
+
+### Copied from qwen_vl_utils.vision_process.py
+import base64
+from io import BytesIO
+import requests
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+
+    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
+        logging.warning(
+            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
+        )
+        if h_bar > w_bar:
+            h_bar = w_bar * MAX_RATIO
+        else:
+            w_bar = h_bar * MAX_RATIO
+    return h_bar, w_bar
+
+
+def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    image_obj = None
+    if isinstance(image, Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        image_obj = Image.open(requests.get(image, stream=True).raw)
+    elif image.startswith("file://"):
+        image_obj = Image.open(image[7:])
+    elif image.startswith("data:image"):
+        if "base64," in image:
+            _, base64_data = image.split("base64,", 1)
+            data = base64.b64decode(base64_data)
+            image_obj = Image.open(BytesIO(data))
+    else:
+        image_obj = Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+    image = image_obj.convert("RGB")
+    ## resize
+    # if "resized_height" in ele and "resized_width" in ele:
+    #     resized_height, resized_width = smart_resize(
+    #         ele["resized_height"],
+    #         ele["resized_width"],
+    #         factor=size_factor,
+    #     )
+    # else:
+    width, height = image.size
+    # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+    # max_pixels = ele.get("max_pixels", MAX_PIXELS)
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=size_factor,
+        min_pixels=MIN_PIXELS,
+        max_pixels=MAX_PIXELS,
+    )
+    image = image.resize((resized_width, resized_height))
+
+    return image
+###
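
Once loaded through `AutoModel`, the retrieval-style helpers in this file (`encode_queries`, `encode_corpus`, `get_image_embeddings`, `get_fused_embeddings`) can be called directly. A hedged end-to-end sketch; the query, documents, and expected ranking are illustrative only:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", trust_remote_code=True
)

queries = ["how to fine-tune a vision-language model"]
corpus = [
    {"title": "Fine-tuning guide", "text": "Steps for adapting Qwen2-VL to new tasks."},
    {"title": "Cooking basics", "text": "How to boil pasta properly."},
]

# encode_queries embeds with is_query=True (default instruction prompt);
# encode_corpus joins title + text with self.sep and embeds with is_query=False.
q_emb = model.encode_queries(queries)
d_emb = model.encode_corpus(corpus)

# Embeddings are L2-normalized, so dot products are cosine similarities.
scores = q_emb @ d_emb.T
print(scores, torch.argmax(scores, dim=-1))
```

Images go through the same path: `get_image_embeddings(images=[...])` accepts PIL images, local paths, or URLs, each resolved by `fetch_image` and resized with `smart_resize` before being passed to the vision tower.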