Samoed committed (verified)
Commit 02d3017 · 1 Parent(s): 24370b0

Update gme_inference.py

Files changed (1)
  1. gme_inference.py +63 -111
gme_inference.py CHANGED
@@ -19,11 +19,11 @@ from transformers import (
     AutoProcessor,
     PreTrainedModel,
     Qwen2VLConfig,
-    Qwen2VLModel,
+    Qwen2VLForConditionalGeneration,
 )
 import os
+from collections.abc import Iterable

-# Define a config class for our model.
 class GmeQwen2VLConfig(Qwen2VLConfig):
     model_type: str = "gme_qwen2_vl"

@@ -39,11 +39,8 @@ class GmeQwen2VLConfig(Qwen2VLConfig):
         self.min_image_tokens = min_image_tokens
         self.max_image_tokens = max_image_tokens
         self.max_length = max_length
-        self.device = device
-AutoConfig.register("gme_qwen2_vl", GmeQwen2VLConfig)


-# Define the model class so that it can be loaded by AutoModel.from_pretrained.
 class GmeQwen2VLForVision2Seq(PreTrainedModel):
     config_class = GmeQwen2VLConfig
     base_model_prefix: str = "base"
@@ -51,29 +48,21 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
         super().__init__(config)
         model_name: str = getattr(config, "_name_or_path", "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
-        # Load the underlying vision-to-sequence model.
-        self.base = Qwen2VLModel.from_pretrained(
-            model_name, trust_remote_code=True, **kwargs
-        )
+
+        self.base = Qwen2VLForConditionalGeneration(config)
         self.normalize: bool = True
-        self.device: str = config.device

         min_pixels: int = config.min_image_tokens * 28 * 28
         max_pixels: int = config.max_image_tokens * 28 * 28
+
         self.max_length: int = config.max_length
-
         self.processor = AutoProcessor.from_pretrained(
             model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
         )
-        self.processor.tokenizer.padding_side = "right"
+        self.processor.tokenizer.padding_side = 'right'
         self.defualt_instruction: str = "You are a helpful assistant."
         self.sep: str = " "

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> GmeQwen2VLForVision2Seq:
-        config = kwargs.pop("config", GmeQwen2VLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))
-        return cls(config, **kwargs)
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -82,9 +71,11 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
+        # pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        # video_grid_thw: Optional[torch.LongTensor] = None,
         pooling_mask: Optional[torch.LongTensor] = None,
-        **kwargs: Any,
+        **kwargs
     ) -> torch.Tensor:
         if inputs_embeds is None:
             inputs_embeds = self.base.model.embed_tokens(input_ids)
@@ -93,6 +84,11 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
             image_mask = input_ids == self.base.config.image_token_id
             inputs_embeds[image_mask] = image_embeds
+        # if pixel_values_videos is not None:
+        #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
+        #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+        #     video_mask = input_ids == self.base.config.video_token_id
+        #     inputs_embeds[video_mask] = video_embeds
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)

@@ -105,48 +101,37 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         )

         pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-        left_padding: bool = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])
+        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
         if left_padding:
             embeddings = outputs.last_hidden_state[:, -1]
         else:
             sequence_lengths = pooling_mask.sum(dim=1) - 1
             batch_size = outputs.last_hidden_state.shape[0]
-            embeddings = outputs.last_hidden_state[
-                torch.arange(batch_size, device=outputs.last_hidden_state.device),
-                sequence_lengths,
-            ]
+            embeddings = outputs.last_hidden_state[torch.arange(
+                batch_size, device=outputs.last_hidden_state.device
+            ), sequence_lengths]
         if self.normalize:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()

-    def embed(
-        self,
-        texts: List[str],
-        images: List[Image.Image],
-        is_query: bool = True,
-        instruction: Optional[str] = None,
-        **kwargs: Any,
-    ) -> torch.Tensor:
+
+    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
         self.base.to(self.device)
-        input_texts: List[str] = []
-        input_images: List[Image.Image] = []
+        # Inputs must be batched
+        input_texts, input_images = list(), list()
         for t, i in zip(texts, images):
             if not is_query or instruction is None:
                 instruction = self.defualt_instruction
-            input_str: str = ""
+            input_str = ''
             if i is None:
                 input_images = None # All examples in the same batch are consistent
             else:
-                input_str += "<|vision_start|><|image_pad|><|vision_end|>"
+                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
                 i = fetch_image(i)
                 input_images.append(i)
             if t is not None:
                 input_str += t
-            msg: str = (
-                f"<|im_start|>system\n{instruction}<|im_end|>\n"
-                f"<|im_start|>user\n{input_str}<|im_end|>\n"
-                f"<|im_start|>assistant\n<|endoftext|>"
-            )
+            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
             input_texts.append(msg)

         inputs = self.processor(
@@ -155,22 +140,22 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             padding=True,
             truncation=True,
             max_length=self.max_length,
-            return_tensors="pt",
+            return_tensors='pt'
         )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        inputs = {k: v.to(self.device) for k, v in inputs.items()} # TODO
         with torch.no_grad():
             embeddings = self.forward(**inputs)
         return embeddings

-    def encode(self, sentences: List[str], **kwargs: Any) -> torch.Tensor:
-        # When no images are provided, we pass a list of Nones.
-        return self.embed(texts=sentences, images=[None] * len(sentences), **kwargs)
+    def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
+        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)

-    def encode_queries(self, queries: List[str], **kwargs: Any) -> torch.Tensor:
-        return self.encode(queries, **kwargs)
+    def encode_queries(self, queries: List[str], **kwargs):
+        embeddings = self.encode(queries, **kwargs)
+        return embeddings

-    def encode_corpus(self, corpus: Union[Dict[str, List[str]], List[Dict[str, str]]], **kwargs: Any) -> torch.Tensor:
-        if isinstance(corpus, dict):
+    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
+        if type(corpus) is dict:
             sentences = [
                 (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                 if "title" in corpus
@@ -182,49 +167,56 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
                 (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                 for doc in corpus
             ]
-        return self.encode(sentences, is_query=False, **kwargs)
+        embeddings = self.encode(sentences, is_query=False, **kwargs)
+        return embeddings

-    def get_image_embeddings(self, images: Union[List[Image.Image], DataLoader], **kwargs: Any) -> torch.Tensor:
+    def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
         return self.get_fused_embeddings(images=images, **kwargs)

-    def get_text_embeddings(self, texts: List[str], **kwargs: Any) -> torch.Tensor:
+    def get_text_embeddings(self, texts: list[str], **kwargs):
         return self.get_fused_embeddings(texts=texts, **kwargs)

-
-    def get_fused_embeddings(
-        self,
-        texts: Optional[List[str]] = None,
-        images: Optional[Union[List[Image.Image], DataLoader]] = None,
-        **kwargs: Any,
-    ) -> torch.Tensor:
+    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
         if isinstance(images, DataLoader):
             image_loader = images
             batch_size = image_loader.batch_size
             image_loader.dataset.transform = None
         else:
-            batch_size = kwargs.pop("batch_size", 32)
+            batch_size = kwargs.pop('batch_size', 32)
             if images is None:
-                # If texts are provided without images, create dummy image batches.
-                image_loader = [None] * ((len(texts) + batch_size - 1) // batch_size)
+                image_loader = None
             else:
-                image_loader = images
+                image_loader = DataLoader(
+                    images,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    collate_fn=custom_collate_fn,
+                    num_workers=min(math.floor(os.cpu_count() / 2), 8),
+                )
+
+        if texts is None:
+            assert image_loader is not None
+            n_batch = len(image_loader)
+        else:
+            n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
+            image_loader = image_loader or [None] * n_batch

-        n_batch: int = (len(texts) // batch_size + int(len(texts) % batch_size > 0)) if texts is not None else len(image_loader)
-        all_embeddings: List[torch.Tensor] = []
+        all_embeddings = list()
         none_batch = [None] * batch_size
-        show_progress_bar: bool = kwargs.pop("show_progress_bar", True)
-        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc="encode")
+        show_progress_bar = kwargs.pop('show_progress_bar', True)
+        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
         for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
-            text_batch: List[Optional[str]] = none_batch if texts is None else texts[n: n + batch_size]
+            text_batch = none_batch if texts is None else texts[n: n+batch_size]
             img_batch = none_batch if img_batch is None else img_batch
             embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
             pbar.update(1)
             all_embeddings.append(embeddings.cpu())
         pbar.close()
-        return torch.cat(all_embeddings, dim=0)
+        all_embeddings = torch.cat(all_embeddings, dim=0)
+        return all_embeddings

-from transformers import AutoModelForVision2Seq
-AutoModelForVision2Seq.register(GmeQwen2VLConfig, GmeQwen2VLForVision2Seq)
+def custom_collate_fn(batch):
+    return batch

 # Utility functions (copied from your vision processing code)
 IMAGE_FACTOR: int = 28
@@ -309,43 +301,3 @@ def fetch_image(image: Union[str, Image.Image], size_factor: int = IMAGE_FACTOR)
     )
     image = image.resize((resized_width, resized_height))
     return image
-
-
-# # For backward compatibility, you can add a from_pretrained classmethod.
-# @classmethod
-# def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> GmeQwen2VLForVision2Seq:
-#     config = GmeQwen2VLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-#     return cls(config, **kwargs)
-
-
-# # Monkey-patch the from_pretrained method to our class so that
-# # one can load the model with AutoModel.from_pretrained.
-# GmeQwen2VLForVision2Seq.from_pretrained = from_pretrained.__get__(GmeQwen2VLForVision2Seq)
-
-
-if __name__ == "__main__":
-    texts = [
-        "What kind of car is this?",
-        "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
-    ]
-    images = [
-        "https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg",
-        "https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg",
-    ]
-
-    # You can now load your model with AutoModel as long as your repository's config JSON has the "architectures" field set.
-    model = AutoModel.from_pretrained("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
-    # Alternatively, load it directly via our class:
-    # model = GmeQwen2VLForVision2Seq.from_pretrained("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
-
-    # Single-modal embedding examples:
-    e_text = model.get_text_embeddings(texts=texts)
-    e_image = model.get_image_embeddings(images=images)
-    print("Text-Image similarity:", (e_text * e_image).sum(-1))
-    # Example with different instruction:
-    e_query = model.get_text_embeddings(texts=texts, instruction="Find an image that matches the given text.")
-    e_corpus = model.get_image_embeddings(images=images, is_query=False)
-    print("Query-Corpus similarity:", (e_query * e_corpus).sum(-1))
-    # Fused-modal embedding:
-    e_fused = model.get_fused_embeddings(texts=texts, images=images)
-    print("Fused-modal similarity:", (e_fused[0] * e_fused[1]).sum())
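
For reference, a minimal usage sketch of the class after this commit, adapted from the __main__ example that the diff removes. It assumes the repository's config maps the gme_qwen2_vl architecture to GmeQwen2VLForVision2Seq so that AutoModel.from_pretrained resolves to this class, and that trust_remote_code=True is needed; both are assumptions about the repo setup, not something this diff itself shows.

# Usage sketch. Assumptions: the repo's auto_map/architectures entries point AutoModel
# at GmeQwen2VLForVision2Seq, and trust_remote_code=True is required to load it.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", trust_remote_code=True
)

texts = [
    "What kind of car is this?",
    "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
]
images = [
    "https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg",
    "https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg",
]

# encode()/get_text_embeddings() now route through get_fused_embeddings(), so text-only,
# image-only, and fused calls all share the same batching and progress-bar path.
e_text = model.get_text_embeddings(texts=texts)
e_image = model.get_image_embeddings(images=images)
print("Text-Image similarity:", (e_text * e_image).sum(-1))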