Yukang committed on
Commit
d2144bd
·
verified ·
1 Parent(s): 34ea7b1

Upload 2 files

Browse files
Files changed (2) hide show
  1. media.py +1 -4
  2. modeling_vila.py +106 -11
media.py CHANGED
@@ -11,7 +11,7 @@ import requests
11
  from transformers import PretrainedConfig
12
 
13
  # from llava.constants import MEDIA_TOKENS
14
- # from llava.media import Image, Video
15
  # from llava.utils import make_list
16
  # from llava.utils.logging import logger
17
 
@@ -31,9 +31,6 @@ class Image(File):
31
  pass
32
 
33
 
34
- class Video(File):
35
- pass
36
-
37
def make_list(obj: Any) -> List:
    """Return *obj* unchanged if it is already a list, else wrap it in one."""
    if isinstance(obj, list):
        return obj
    return [obj]
39
 
 
11
  from transformers import PretrainedConfig
12
 
13
  # from llava.constants import MEDIA_TOKENS
14
+ from llava.media import Image, Video
15
  # from llava.utils import make_list
16
  # from llava.utils.logging import logger
17
 
 
31
  pass
32
 
33
 
 
 
 
34
  def make_list(obj: Any) -> List:
35
  return obj if isinstance(obj, list) else [obj]
36
 
modeling_vila.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import copy
2
  import json
3
  import logging
@@ -142,14 +143,97 @@ class VILAPretrainedModel(PreTrainedModel):
142
  self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
143
  ), "At least one of the components must be instantiated."
144
 
145
-
146
-
147
  @classmethod
148
- def save_pretrained(
149
- cls,
150
- ):
151
- raise NotImplementedError
152
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  @classmethod
155
  def from_pretrained(
@@ -202,6 +286,16 @@ class VILAPretrainedModel(PreTrainedModel):
202
  if getattr(self.config, "mm_projector_cfg", None) is None:
203
  self.config.mm_projector_cfg = self.mm_projector.config
204
 
 
 
 
 
 
 
 
 
 
 
205
  def get_vision_tower(self):
206
  vision_tower = getattr(self, "vision_tower", None)
207
  if type(vision_tower) is list:
@@ -408,7 +502,7 @@ class VILAForCasualLM(VILAPretrainedModel):
408
  if self.training:
409
  # Gather metainfo of media objects from all ranks
410
  info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
411
- infos = list(chain(*all_gather(info)))
412
 
413
  # The entire batch does not contain any media objects of this type.
414
  if not infos:
@@ -750,7 +844,7 @@ class VILAForCasualLM(VILAPretrainedModel):
750
  if images is not None:
751
  if media is not None:
752
  raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
753
- logger.warning("The 'images' argument is deprecated. Please use 'media' instead.")
754
  media = {"image": images}
755
 
756
  if media_config is None:
@@ -845,7 +939,7 @@ class VILAForCasualLM(VILAPretrainedModel):
845
  images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
846
  media[name] = [image for image in images]
847
  elif name == "video":
848
- if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
849
  media[name] = [
850
  process_images(
851
  images,
@@ -856,7 +950,7 @@ class VILAForCasualLM(VILAPretrainedModel):
856
  ).half()
857
  for images in media[name]
858
  ]
859
- elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
860
  self.config.image_processor = self.vision_tower.image_processor
861
  if type(self.config.s2_scales) is str:
862
  self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
@@ -930,3 +1024,4 @@ class VILAForCasualLM(VILAPretrainedModel):
930
  if generation_config.eos_token_id is None:
931
  generation_config.eos_token_id = self.tokenizer.eos_token_id
932
  return generation_config
 
 
1
+ import shutil
2
  import copy
3
  import json
4
  import logging
 
143
  self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
144
  ), "At least one of the components must be instantiated."
145
 
 
 
146
@classmethod
def convert_vila_dev_ckpt_to_remote(cls, model_path: str, output_dir: str = None, *model_args, **kwargs):
    """Convert a VILA dev checkpoint into a remote-code loadable HF checkpoint.

    If *model_path* is not a local directory it is treated as a Hub repo id and
    downloaded (into *output_dir* when given). The checkpoint's ``config.json``
    is then rewritten with the remote-code ``auto_map`` entries and this
    module's ``.py`` files are copied next to it.
    """
    # Only touch the Hub when the path is not already a local checkout;
    # the original code had a no-op `model_path = model_path` branch and
    # queried repo_exists() even for local directories.
    if not os.path.isdir(model_path):
        # Lazy import: the Hub client is only needed for remote repos.
        from huggingface_hub import HfApi, snapshot_download

        api = HfApi()
        if api.repo_exists(model_path):
            model_path = snapshot_download(model_path, local_dir=output_dir)
            print("downloading HF model to", model_path)

    cfg_path = os.path.join(model_path, "config.json")
    # Context managers so the config file handles are not leaked.
    with open(cfg_path) as f:
        config = json.load(f)
    config["version"] = "2.0"  # nvila tag
    config["architectures"] = ["VILAForCasualLM"]
    config["auto_map"] = {
        "AutoConfig": "modeling_vila.VILAConfig",
        "AutoModel": "modeling_vila.VILAForCasualLM",
        "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
    }
    config["model_type"] = "vila"
    with open(cfg_path, "w") as f:
        json.dump(config, f, indent=2)
    cls.copy_remote_py_files(model_path)
170
+
171
@classmethod
def copy_remote_py_files(cls, output_dir):
    """Copy every .py file sitting next to this module into *output_dir*.

    Needed so a saved checkpoint can be reloaded with trust_remote_code
    (README copying is mentioned in the original comment but not implemented).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    for entry in os.listdir(module_dir):
        if not entry.endswith(".py"):
            continue
        src = os.path.join(module_dir, entry)
        if os.path.isfile(src):
            shutil.copy(src, output_dir)
            print("[HF remote code] copying", src, "to", output_dir)
182
+
183
def save_pretrained(self, output_dir, state_dict=None, safe_serialization=None):
    """Save the composite model as separate sub-checkpoints under *output_dir*.

    Writes the tokenizer and LLM to ``<output_dir>/llm``, the vision tower
    (plus its image processor) to ``<output_dir>/vision_tower`` and the
    projector to ``<output_dir>/mm_projector``, updating the matching
    ``self.config`` sub-configs, then copies this module's .py files into
    *output_dir* for remote-code reloading.

    NOTE(review): ``safe_serialization`` is accepted but never used — confirm
    whether it should be forwarded to the sub-model ``save_pretrained`` calls.
    """
    if state_dict is None:
        # otherwise fetch from deepspeed
        # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
        state_dict = self.state_dict()

    # Tokenizer is saved alongside the LLM weights.
    if getattr(self, "tokenizer", None):
        self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

    if self.get_llm():
        print(f"saving llm to {osp.join(output_dir, 'llm')}")
        self.llm.config._name_or_path = osp.join(output_dir, "llm")
        # Slice LLM weights out of the full state dict and strip the "llm."
        # prefix. NOTE(review): the substring test `"llm" in k` also matches
        # keys that merely contain "llm" elsewhere — confirm key naming.
        llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
        self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
        self.config.llm_cfg = self.llm.config

    if self.get_vision_tower():
        print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
        self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
        # Same slicing pattern for the (doubly nested) vision tower weights.
        vision_tower_state_dict = OrderedDict(
            {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
        )
        self.vision_tower.vision_tower.save_pretrained(
            os.path.join(output_dir, "vision_tower"),
            state_dict=vision_tower_state_dict,
        )
        self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
        self.config.vision_tower_cfg = self.vision_tower.config
        # Drop auto_map from the saved vision-tower config unless the tower's
        # class name contains "radio" — presumably radio towers need their
        # remote-code mapping preserved; confirm with the tower implementations.
        if hasattr(self.config.vision_tower_cfg, "auto_map"):
            if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                delattr(self.config.vision_tower_cfg, "auto_map")

    if self.get_mm_projector():
        print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
        self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
        mm_projector_state_dict = OrderedDict(
            {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
        )
        self.mm_projector.save_pretrained(
            os.path.join(output_dir, "mm_projector"),
            state_dict=mm_projector_state_dict,
        )
        self.config.mm_projector_cfg = self.mm_projector.config

    ## update and save top-level config
    self.config._name_or_path = output_dir
    self.config.architectures = [self.__class__.__name__]
    # print(self.config)
    # self.config.save_pretrained(output_dir)

    ## copy .py and README for next loading remote code
    self.copy_remote_py_files(output_dir)
235
+
236
+
237
 
238
  @classmethod
239
  def from_pretrained(
 
286
  if getattr(self.config, "mm_projector_cfg", None) is None:
287
  self.config.mm_projector_cfg = self.mm_projector.config
288
 
289
def get_llm(self):
    """Return the underlying language model, unwrapping a one-element list.

    Returns None when no ``llm`` attribute has been set.
    """
    model = getattr(self, "llm", None)
    if type(model) is list:
        model = model[0]
    return model
294
+
295
def get_lm_head(self):
    """Return the LM head of the underlying language model, or None."""
    return getattr(self.get_llm(), "lm_head", None)
298
+
299
  def get_vision_tower(self):
300
  vision_tower = getattr(self, "vision_tower", None)
301
  if type(vision_tower) is list:
 
502
  if self.training:
503
  # Gather metainfo of media objects from all ranks
504
  info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
505
+ infos = list(chain(all_gather(info)))
506
 
507
  # The entire batch does not contain any media objects of this type.
508
  if not infos:
 
844
  if images is not None:
845
  if media is not None:
846
  raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
847
+ print("The 'images' argument is deprecated. Please use 'media' instead.")
848
  media = {"image": images}
849
 
850
  if media_config is None:
 
939
  images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
940
  media[name] = [image for image in images]
941
  elif name == "video":
942
+ if False: #self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
943
  media[name] = [
944
  process_images(
945
  images,
 
950
  ).half()
951
  for images in media[name]
952
  ]
953
+ elif False: #self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
954
  self.config.image_processor = self.vision_tower.image_processor
955
  if type(self.config.s2_scales) is str:
956
  self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
 
1024
  if generation_config.eos_token_id is None:
1025
  generation_config.eos_token_id = self.tokenizer.eos_token_id
1026
  return generation_config
1027
+