Ligeng-Zhu committed
Commit cb6db22 · verified · 1 Parent(s): 3383eeb

Upload files with `vila-upload`.


Upload README.md
Upload tokenizer_utils.py
Upload builder.py
Upload mm_utils.py
Upload auto_processor.py
Upload modeling_vila.py

Files changed (6)
  1. README.md +3 -3
  2. auto_processor.py +208 -57
  3. builder.py +4 -0
  4. mm_utils.py +5 -2
  5. modeling_vila.py +49 -13
  6. tokenizer_utils.py +8 -1
README.md CHANGED
@@ -67,7 +67,7 @@ model.eval()
 gpt_conv = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "demo_images/demo_img_1.png"},
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
         {"type": "text", "text": "Describe this image."}
     ]
 }]
@@ -106,14 +106,14 @@ model.eval()
 gpt_conv1 = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "demo_images/demo_img_1.png"},
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
         {"type": "text", "text": "Describe this image."}
     ]
 }]
 gpt_conv2 = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "demo_images/demo_img_2.png"},
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example_vqa.jpg"},
         {"type": "text", "text": "Describe this image for me. Provide a detailed description of the image."}
     ]
 }]
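
Both README snippets now point at hosted example images instead of files under demo_images/, so they run without cloning the demo assets. For orientation, a condensed sketch of the flow these snippets feed into (the repo id, loading flags, and device handling below are assumptions for illustration, not part of this commit):

from transformers import AutoModel, AutoProcessor

model_path = "Efficient-Large-Model/NVILA-Lite-2B-hf-preview"  # assumed repo id
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

gpt_conv = [{
    "role": "user",
    "content": [
        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
        {"type": "text", "text": "Describe this image."},
    ],
}]

# apply_chat_template converts GPT-style turns into VILA's internal format and
# downloads http(s) image paths to temp files; the processor call then tokenizes,
# left-pads and returns input_ids, attention_mask, media and media_config.
inputs = processor([processor.apply_chat_template(gpt_conv)])
output_ids = model.generate(
    input_ids=inputs.input_ids,
    media=inputs.media,
    media_config=inputs.media_config,
    attention_mask=inputs.attention_mask,
    max_new_tokens=128,
)
# The prompt ids are prepended to the generated ids by default (see modeling_vila.py).
print(processor.batch_decode(output_ids, skip_special_tokens=True))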
auto_processor.py CHANGED
@@ -3,8 +3,11 @@ import os
 import os.path as osp
 import warnings
 from collections import defaultdict
-from typing import List, Union
+from io import BytesIO
+from typing import List, Optional, Union
 
+import PIL.Image
+import requests
 import torch
 from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
 from transformers.feature_extraction_utils import BatchFeature
@@ -18,35 +21,73 @@ from .media import Image, Video, extract_media
 from .mm_utils import process_image, process_images
 from .tokenizer_utils import tokenize_conversation
 
+
+def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
+    if pil_image.mode == "RGBA":
+        white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
+        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
+        return white_background
+    else:
+        return pil_image.convert("RGB")
+
+
+def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
+    if "image" in ele:
+        image = ele["image"]
+    else:
+        image = ele["image_url"]
+    image_obj = None
+    if isinstance(image, PIL.Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        response = requests.get(image, stream=True)
+        image_obj = PIL.Image.open(BytesIO(response.content))
+    elif image.startswith("file://"):
+        image_obj = PIL.Image.open(image[7:])
+    elif image.startswith("data:image"):
+        if "base64," in image:
+            _, base64_data = image.split("base64,", 1)
+            data = base64.b64decode(base64_data)
+            image_obj = PIL.Image.open(BytesIO(data))
+    else:
+        image_obj = PIL.Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+    image = to_rgb(image_obj)
+
+    return image
+
+
 def fetch_image_url_or_fpath(url_or_fpath):
     if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
         import tempfile
+
         import requests
-
+
         # Download the image to a temporary file
         temp_dir = tempfile.mkdtemp()
         temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
-
+
         response = requests.get(url_or_fpath, stream=True)
         response.raise_for_status()
-
+
         with open(temp_file, "wb") as f:
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
-
+
         return temp_file
     elif url_or_fpath.startswith("file://"):
         fpath = url_or_fpath.replace("file://", "")
         assert osp.exists(fpath), f"File {fpath} does not exist"
         return fpath
     elif osp.exists(url_or_fpath):
-        assert osp.isfile(url_or_fpath), f"File {url_or_fpath} is not a file"
+        assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
        return url_or_fpath
     else:
         raise ValueError(f"Unsupported image path: {url_or_fpath}")
-
 
-def __pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="left"):
+
+def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
     # tensor shape is (batch_size, seq_len)
     max_len = max([ids.shape[1] for ids in input_ids_list])
     if target_len is not None:
@@ -66,6 +107,36 @@ def __pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="left"):
     return torch.cat(new_input_ids_list, dim=0)
 
 
+def extract_value_from_conv(chat):
+    value = []
+    if isinstance(chat["content"], str):
+        # vila_chat["value"].append(chat["content"])
+        value.append(chat["content"])
+        return value
+
+    # otherwise, it's a list of content
+    for content in chat["content"]:
+        if content["type"] == "image":
+            if "path" in content:
+                # VILA style, can be either filepath or http url
+                value.append(Image(fetch_image_url_or_fpath(content["path"])))
+            elif "image" in content:
+                # Qwen style
+                value.append(Image(fetch_image_url_or_fpath(content["image"])))
+            elif "image_pil" in content:
+                # Qwen style
+                assert isinstance(content["image_pil"], PIL.Image.Image), f"Type of {media_key} must be PIL.Image.Image"
+                value.append(content["image_pil"])
+            else:
+                raise ValueError(f"Type = `image` , but no `path` or `image` in | {content=}, {conversation=}")
+        elif content["type"] == "text":
+            value.append(content["text"])
+        # NOTE(ligeng): video supports are needed here
+        else:
+            raise ValueError(f"Unsupported content type: {content['type']}")
+    return value
+
+
 class VILAProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {
@@ -74,8 +145,6 @@ class VILAProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
-
-
 class VILAProcessor(ProcessorMixin):
     # attributes = ["image_processor", "tokenizer"]
     attributes = []
@@ -85,16 +154,83 @@ class VILAProcessor(ProcessorMixin):
     # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
 
     def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
-        # self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
-        # self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
         self.image_token = MEDIA_TOKENS["image"]
         self.video_token = MEDIA_TOKENS["video"]
         self.config = config
         self.image_processor = image_processor
         self.tokenizer = tokenizer
-
+        # self.pad_token_id = tokenizer.pad_token_id
+        self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0]
+        self.eos_token_id = self.tokenizer.eos_token_id
+        # self.pad_token_id = 151643
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
+    @staticmethod
+    def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+        """
+        referernce from qwen_vl_utils
+        """
+        vision_infos = []
+        if isinstance(conversations[0], dict):
+            conversations = [conversations]
+        for conversation in conversations:
+            for message in conversation:
+                if isinstance(message["content"], list):
+                    for ele in message["content"]:
+                        if (
+                            "image" in ele
+                            or "image_url" in ele
+                            or "video" in ele
+                            or ele["type"] in ("image", "image_url", "video")
+                        ):
+                            vision_infos.append(ele)
+        return vision_infos
+
+    @staticmethod
+    def process_vision_info(
+        conversations: list[dict] | list[list[dict]],
+        return_video_kwargs: bool = False,
+    ) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
+        """
+        referernce from qwen_vl_utils
+        """
+        vision_infos = extract_vision_info(conversations)
+        ## Read images or videos
+        image_inputs = []
+        video_inputs = []
+        video_sample_fps_list = []
+        for vision_info in vision_infos:
+            if "image" in vision_info or "image_url" in vision_info:
+                image_inputs.append(fetch_image(vision_info))
+            elif "video" in vision_info:
+                video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
+                video_sample_fps_list.append(video_sample_fps)
+                video_inputs.append(video_input)
+            else:
+                raise ValueError("image, image_url or video should in content.")
+        if len(image_inputs) == 0:
+            image_inputs = None
+        if len(video_inputs) == 0:
+            video_inputs = None
+        if return_video_kwargs:
+            return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+        return image_inputs, video_inputs
+
+    @staticmethod
+    def move_data_to_device(cls, prompt_inputs):
+        def _move_data_to_device(item):
+            # wrap function grpo trainer _prepare_input
+            kwargs = {"device": cls.args.device}
+            if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
+                kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+            return item.to(**kwargs)
+
+        prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
+        prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
+        if "image" in prompt_inputs.media:
+            prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
+        return prompt_inputs
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         if os.path.isdir(pretrained_model_name_or_path):
@@ -115,40 +251,59 @@ class VILAProcessor(ProcessorMixin):
         return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)
 
     def __repr__(self):
-        return (
-            f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
-        )
+        # NOTE(ligeng): hard coded image_processor to avoid serialization error. Dirty fix
+        return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
 
     def __call__(
         self,
-        conversation,
-        images: ImageInput = None,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        videos: VideoInput = None,
+        conversation=None,
         **kwargs: Unpack[VILAProcessorKwargs],
     ) -> BatchFeature:
-        if images is not None:
-            warnings.warn("images is not supported in __call__")
+        """
+        The `conv` will be look like
+        [
+            {
+                'from': 'human',
+                'value': [
+                    <transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
+                    'What are the common elements in these pictures?'
+                ]
+            }
+        ]
+        and `conversation` will be a list of such `conv`s
+        """
+        if kwargs.get("text", None) is not None:
+            conversation = kwargs.get("text")
+        assert conversation is not None, "`conversation` or `text` is required"
+        padding_side = kwargs.get("padding_side", "left")
 
-        input_ids = []
+        input_ids_list = []
+        attention_mask = []
         media = defaultdict(list)
         media_config = defaultdict(dict)
         for conv in conversation:
-            feat = self.__single_call__(conv, images, text, videos, **kwargs)
-            input_ids.append(feat.input_ids)
+            feat = self.__single_call__(conv, **kwargs)
+            input_ids_list.append(feat.input_ids)
+            attention_mask.append(feat.attention_mask)
            for name in feat.media:
                 media[name] += feat.media[name]
             for name in feat.media_config:
                 media_config[name].update(feat.media_config[name])
 
+        input_ids = pad_fn(
+            input_ids_list,
+            padding_value=self.pad_token_id,
+            padding_side=padding_side,
+        )
+        # ignore the pad token in the attention mask
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        attention_mask[input_ids == self.pad_token_id] = False
+        # print("[DEBUGAAA]", self.pad_token_id, self.tokenizer.pad_token_id); exit(0)
+
         return BatchFeature(
             data={
-                # "input_ids": torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id),
-                "input_ids": __pad_fn(
-                    input_ids,
-                    padding_value=self.tokenizer.pad_token_id,
-                    padding_side="left",
-                ),
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
                 "media": media,
                 "media_config": media_config,
             }
@@ -195,9 +350,18 @@
                 ]
             else:
                 raise ValueError(f"Unsupported media type: {name}")
-        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
-        # Set up the generation config
-        return BatchFeature(data={"input_ids": input_ids, "media": media, "media_config": media_config})
+
+        inputs = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True, return_ids_only=False)
+        input_ids = inputs.input_ids[0].cuda().unsqueeze(0)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        return BatchFeature(
+            data={
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "media": media,
+                "media_config": media_config,
+            }
+        )
 
     def batch_decode(self, *args, **kwargs):
         """
@@ -235,38 +399,26 @@
         image_processor_input_names = self.image_processor.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
-    # inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
-    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
+    def convert_gpt_conv_to_vila_conv(self, conversation):
         vila_conv = []
         for chat in conversation:
             vila_chat = {"from": "", "value": []}
-            if chat["role"] == "user":
+            if chat["role"] in ("user", "system"):
                 # user allows to input image and text
-                vila_chat["from"] = "human"
-                for content in chat["content"]:
-                    if content["type"] == "image":
-                        if "path" in content:
-                            # VILA style
-                            vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["path"])))
-                        elif "image" in content:
-                            # Qwen style
-                            vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["image"])))
-                        else:
-                            raise ValueError(f"Unsupported content type `image`: {content}, `image` and `path` are required")
-                    elif content["type"] == "text":
-                        vila_chat["value"].append(content["text"])
-                    # NOTE(ligeng): video supports are needed here
-                    else:
-                        raise ValueError(f"Unsupported content type: {content['type']}")
+                vila_chat["from"] = "human" if chat["role"] == "user" else "system"
+                vila_chat["value"] = extract_value_from_conv(chat)
             elif chat["role"] == "assistant":
                 vila_chat["from"] = "gpt"
-                for content in chat["content"]:
-                    assert content["type"] == "text", f"Unsupported content type: {content['type']}"
-                    vila_chat["value"].append(content["text"])
+                vila_chat["value"] = extract_value_from_conv(chat)
+            else:
+                raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
             vila_conv.append(vila_chat)
 
         return vila_conv
 
+    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
+        return self.convert_gpt_conv_to_vila_conv(conversation)
+
 
 if __name__ == "__main__":
     # gpt style: user, assistant
@@ -301,7 +453,6 @@ if __name__ == "__main__":
     # print(model.config)
     # print(model.tokenizer)
     # print(res)
-    # exit(0)
 
     processor = VILAProcessor(
         config=model.config,
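
The core behavioral change in auto_processor.py is that __call__ now batches a list of conversations: each one is tokenized via __single_call__, the id tensors are left-padded to a common length with pad_fn, and an attention_mask marking the pad positions is returned next to media and media_config. A minimal standalone sketch of that padding logic (pure PyTorch, not the repo's code; 151643 is the Qwen <|endoftext|> id mentioned in the comments above):

import torch

def left_pad(input_ids_list, padding_value):
    # Mirror of what pad_fn plus the mask construction in VILAProcessor.__call__ do:
    # left-pad each (1, seq_len) tensor to the batch max length, then mark pad
    # positions as False in the boolean attention mask.
    max_len = max(ids.shape[1] for ids in input_ids_list)
    padded = []
    for ids in input_ids_list:
        pad = torch.full((1, max_len - ids.shape[1]), padding_value, dtype=ids.dtype)
        padded.append(torch.cat([pad, ids], dim=1))
    input_ids = torch.cat(padded, dim=0)
    attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    attention_mask[input_ids == padding_value] = False
    return input_ids, attention_mask

ids, mask = left_pad([torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])], padding_value=151643)
print(ids)   # tensor([[     5,      6,      7], [151643,      8,      9]])
print(mask)  # tensor([[ True,  True,  True], [False,  True,  True]])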
builder.py CHANGED
@@ -33,6 +33,7 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizer,
 )
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 # from .conversation import *
 from .conversation import SeparatorStyle, default_conversation
@@ -202,6 +203,9 @@ def build_llm_and_tokenizer(
             fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
         )
     else:
+        if is_deepspeed_zero3_enabled():
+            # NOTE: found by wei, need to pop out device_map when using zero3
+            kwargs.pop("device_map")
         llm = AutoModelForCausalLM.from_pretrained(
             model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
         )
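
The builder change drops device_map before loading the LLM when DeepSpeed ZeRO-3 is active, since ZeRO-3 shards and places parameters itself and an explicit device map conflicts with that placement. A defensive variant of the same guard (the function name and the pop default below are illustrative; the committed code pops the key unconditionally inside build_llm_and_tokenizer):

from transformers import AutoModelForCausalLM
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

def load_llm(model_name_or_path, llm_cfg, torch_dtype, **kwargs):
    # Under ZeRO-3, DeepSpeed owns parameter placement; a user-supplied
    # device_map would fight with it, so remove it before from_pretrained.
    if is_deepspeed_zero3_enabled():
        kwargs.pop("device_map", None)
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path, config=llm_cfg, torch_dtype=torch_dtype, **kwargs
    )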
mm_utils.py CHANGED
@@ -521,8 +521,11 @@ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False,
     return new_images
 
 
-def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
-    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
+def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
+    if return_ids:
+        return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
+    else:
+        return tokenizer(prompt, return_tensors=return_tensors)
 
 
 def is_gemma_tokenizer(tokenizer):
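
tokenizer_image_token now takes a return_ids flag: the default keeps the previous behavior (a bare 1-D id tensor), while return_ids=False returns the full tokenizer output so that tokenize_conversation, and through it the processor, can also read attention_mask. A small usage sketch (the import path and the gpt2 tokenizer are stand-ins; the real callers pass VILA's own tokenizer and prompts containing media tokens):

from transformers import AutoTokenizer
from mm_utils import tokenizer_image_token  # assumed import path when run inside the repo

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer for illustration

# Default: ids only, shape (seq_len,) -- identical to the behavior before this commit.
ids = tokenizer_image_token("describe the image", tokenizer, return_tensors="pt")

# New: the full BatchEncoding, so both input_ids and attention_mask are available.
enc = tokenizer_image_token("describe the image", tokenizer, return_tensors="pt", return_ids=False)
print(enc.input_ids.shape, enc.attention_mask.shape)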
modeling_vila.py CHANGED
@@ -428,6 +428,12 @@ class VILAPretrainedModel(PreTrainedModel):
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
 
+        self.pad_token_list = (
+            self.tokenizer.pad_token_id,
+            self.tokenizer.eos_token_id,
+            self.tokenizer.tokenize("<|endoftext|>")[0],  # for qwen
+        )
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -650,11 +656,9 @@ class VILAForCasualLM(VILAPretrainedModel):
                     name = media_tokens[input_ids[k][pos].item()]
                    input = media_embeds[name].popleft()
                    label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
-                    # print(f"{self.tokenizer.padding_side} [media] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:pos+1])}"); python_input()
-                elif input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
+                elif input_ids[k][pos].item() in self.pad_token_list:
                     end = pos + 1
                     pos = end
-                    # print(f"[skip PAD/EOS] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
                     continue
                 else:
                     end = pos
@@ -662,7 +666,6 @@ class VILAForCasualLM(VILAPretrainedModel):
                         end += 1
                     input = text_embeds[k][pos:end]
                     label = labels[k][pos:end]
-                    # print(f"[text] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
 
                 inputs_mk.append(input)
                 labels_mk.append(label)
@@ -1018,6 +1021,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         media: Optional[Dict[str, List[torch.Tensor]]] = None,
         images: Optional[torch.FloatTensor] = None,
         media_config: Optional[List] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1074,21 +1078,56 @@
 
         return outputs
 
-    @torch.inference_mode()
+    # @torch.inference_mode()
     def generate(
         self,
         input_ids: Optional[torch.FloatTensor] = None,
         media: Optional[Dict[str, List[torch.Tensor]]] = None,
         media_config: Dict[str, Dict[str, Any]] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        return_output_ids_only: bool = False,
         **generation_kwargs,
-    ):
+    ) -> torch.LongTensor:
+        model_training_status = False
         if self.training:
             warnings.warn(
-                "Model is in training mode, using default padding strategy to right. This is not recommended for generation."
+                "Model is in training mode, using default padding strategy to right. This is not recommended for generation. We implicitly set the model to evaluation mode and restore the model training status after generation."
             )
+            self.eval()
+            model_training_status = True
+        """
+        input_tokens: <image> describe the image
+        media: [Tensor(1, 3, 384, 384), ]
+        ----------->
+        input_tokens: 36000 001 002 003 004
+        input_emds: <media emd> 001 002 003 004
+        """
+        # TODO: there is still a padding left vs right issue unsovled here.
+        # print("prev args:",input_ids.shape, media, media_config, None, attention_mask)
         inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
-        return self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
+        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds.mean(), inputs_embeds.std())
+        # print("attention_mask", attention_mask.shape, attention_mask)
+        output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
+        # print("output_ids", self.tokenizer.batch_decode(output_ids))
+        # input("wait for debug")
+        if return_output_ids_only:
+            return_value = output_ids
+        else:
+            # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
+            # print(f"[DEBUG REMOTE] input_ids: {input_ids.shape}, output_ids: {output_ids.shape} attention_mask: {attention_mask.shape} {generation_kwargs=}"); exit(0)
+            generation_config = generation_kwargs.get("generation_config", None)
+            if generation_config is not None:
+                num_generations = generation_config.num_return_sequences
+                repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
+                return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
+            else:
+                return_value = torch.cat([input_ids, output_ids], dim=-1)
+
+        if model_training_status:
+            # restore the model training status
+            self.train()
+
+        return return_value
 
     @torch.inference_mode()
     def generate_content(
@@ -1101,10 +1140,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         conversation = [{"from": "human", "value": prompt}]
 
         # Convert response format to logits processor
-        if response_format:
-            xgr_logits_processor = self.get_xgr_logits_processor(response_format)
-        else:
-            xgr_logits_processor = None
+        xgr_logits_processor = None
 
         # Extract media from the conversation
 
@@ -1173,7 +1209,7 @@
            raise ValueError(f"Unsupported media type: {name}")
 
         # Tokenize the conversation
-        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
+        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
 
         # Set up the generation config
         generation_config = generation_config or self.default_generation_config
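
Two changes in generate matter to callers: it no longer runs under @torch.inference_mode() (the model is switched to eval and restored afterwards if it was in training mode), and by default it returns the prompt ids concatenated with the newly generated ids, matching community VLMs such as Qwen-VL; return_output_ids_only=True restores continuation-only output. A sketch of consuming the default convention (names follow the processor sketch above; illustrative, not code from this commit):

out = model.generate(
    input_ids=inputs.input_ids,
    media=inputs.media,
    media_config=inputs.media_config,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
)
# Default: out == torch.cat([input_ids, output_ids], dim=-1), so strip the prompt
# before decoding; passing return_output_ids_only=True skips this step.
new_tokens = out[:, inputs.input_ids.shape[1]:]
print(model.tokenizer.batch_decode(new_tokens, skip_special_tokens=True))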
tokenizer_utils.py CHANGED
@@ -68,13 +68,16 @@ def tokenize_conversation_legacy(
     return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
 
 
+# NOTE(ligeng): add a return typing to help code analyze
 def tokenize_conversation(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
     add_generation_prompt: bool = False,
     overrides: Optional[Dict[str, str]] = None,
     no_system_prompt: bool = False,
+    return_ids_only=True,
 ) -> torch.Tensor:
+    # print("messages", messages); input()
     # Normalize the conversation before tokenization
     for message in messages:
         message["value"] = message["value"].strip()
@@ -95,6 +98,10 @@ def tokenize_conversation(
             message["role"] = "user"
         elif m["from"] == "gpt":
             message["role"] = "assistant"
+        elif m["from"] == "system":
+            message["role"] = "system"
+            if no_system_prompt:
+                raise ValueError("System prompt is not allowed when no_system_prompt is True.")
         else:
             raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
 
@@ -111,7 +118,7 @@ def tokenize_conversation(
         add_generation_prompt=add_generation_prompt,
         tokenize=False,
     )
-    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
+    return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)
 
 
 def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
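
tokenize_conversation now accepts system turns ("from": "system") and a return_ids_only switch that is forwarded to tokenizer_image_token, which is how the processor obtains an attention_mask alongside the ids. A sketch of the expected message structure (the tokenizer and its chat template come from the VILA repo and are assumed to be in scope here; media tokens are inserted by the processor upstream):

messages = [
    {"from": "system", "value": "You are a helpful assistant."},  # newly supported turn
    {"from": "human", "value": "Describe this image."},
]

# Previous behavior: a 1-D tensor of token ids.
ids = tokenize_conversation(messages, tokenizer, add_generation_prompt=True)

# New behavior: the full tokenizer output, including attention_mask.
enc = tokenize_conversation(messages, tokenizer, add_generation_prompt=True, return_ids_only=False)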