Upload 2 files
Browse files- media.py +1 -4
- modeling_vila.py +106 -11
media.py
CHANGED
@@ -11,7 +11,7 @@ import requests
|
|
11 |
from transformers import PretrainedConfig
|
12 |
|
13 |
# from llava.constants import MEDIA_TOKENS
|
14 |
-
|
15 |
# from llava.utils import make_list
|
16 |
# from llava.utils.logging import logger
|
17 |
|
@@ -31,9 +31,6 @@ class Image(File):
|
|
31 |
pass
|
32 |
|
33 |
|
34 |
-
class Video(File):
|
35 |
-
pass
|
36 |
-
|
37 |
def make_list(obj: Any) -> List:
|
38 |
return obj if isinstance(obj, list) else [obj]
|
39 |
|
|
|
11 |
from transformers import PretrainedConfig
|
12 |
|
13 |
# from llava.constants import MEDIA_TOKENS
|
14 |
+
from llava.media import Image, Video
|
15 |
# from llava.utils import make_list
|
16 |
# from llava.utils.logging import logger
|
17 |
|
|
|
31 |
pass
|
32 |
|
33 |
|
|
|
|
|
|
|
34 |
def make_list(obj: Any) -> List:
|
35 |
return obj if isinstance(obj, list) else [obj]
|
36 |
|
modeling_vila.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import copy
|
2 |
import json
|
3 |
import logging
|
@@ -142,14 +143,97 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
142 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
143 |
), "At least one of the components must be instantiated."
|
144 |
|
145 |
-
|
146 |
-
|
147 |
@classmethod
|
148 |
-
def
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
@classmethod
|
155 |
def from_pretrained(
|
@@ -202,6 +286,16 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
202 |
if getattr(self.config, "mm_projector_cfg", None) is None:
|
203 |
self.config.mm_projector_cfg = self.mm_projector.config
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
def get_vision_tower(self):
|
206 |
vision_tower = getattr(self, "vision_tower", None)
|
207 |
if type(vision_tower) is list:
|
@@ -408,7 +502,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
408 |
if self.training:
|
409 |
# Gather metainfo of media objects from all ranks
|
410 |
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
411 |
-
infos = list(chain(
|
412 |
|
413 |
# The entire batch does not contain any media objects of this type.
|
414 |
if not infos:
|
@@ -750,7 +844,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
750 |
if images is not None:
|
751 |
if media is not None:
|
752 |
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
753 |
-
|
754 |
media = {"image": images}
|
755 |
|
756 |
if media_config is None:
|
@@ -845,7 +939,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
845 |
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
846 |
media[name] = [image for image in images]
|
847 |
elif name == "video":
|
848 |
-
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
849 |
media[name] = [
|
850 |
process_images(
|
851 |
images,
|
@@ -856,7 +950,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
856 |
).half()
|
857 |
for images in media[name]
|
858 |
]
|
859 |
-
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
860 |
self.config.image_processor = self.vision_tower.image_processor
|
861 |
if type(self.config.s2_scales) is str:
|
862 |
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
@@ -930,3 +1024,4 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
930 |
if generation_config.eos_token_id is None:
|
931 |
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
932 |
return generation_config
|
|
|
|
1 |
+
import shutil
|
2 |
import copy
|
3 |
import json
|
4 |
import logging
|
|
|
143 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
144 |
), "At least one of the components must be instantiated."
|
145 |
|
|
|
|
|
146 |
@classmethod
|
147 |
+
def convert_vila_dev_ckpt_to_remote(self, model_path: str, output_dir:str = None, *model_args, **kwargs):
|
148 |
+
# assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
|
149 |
+
from huggingface_hub import HfApi, snapshot_download
|
150 |
+
|
151 |
+
if os.path.isdir(model_path):
|
152 |
+
model_path = model_path
|
153 |
+
api = HfApi()
|
154 |
+
if api.repo_exists(model_path):
|
155 |
+
model_path = snapshot_download(model_path, local_dir=output_dir)
|
156 |
+
print("downloading HF model to", model_path)
|
157 |
+
|
158 |
+
cfg_path = os.path.join(model_path, "config.json")
|
159 |
+
config = json.load(open(cfg_path))
|
160 |
+
config["version"] = "2.0" # nvila tag
|
161 |
+
config["architectures"] = ["VILAForCasualLM"]
|
162 |
+
config["auto_map"] = {
|
163 |
+
"AutoConfig": "modeling_vila.VILAConfig",
|
164 |
+
"AutoModel": "modeling_vila.VILAForCasualLM",
|
165 |
+
"AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
|
166 |
+
}
|
167 |
+
config["model_type"] = "vila"
|
168 |
+
json.dump(config, open(cfg_path, "w"), indent=2)
|
169 |
+
self.copy_remote_py_files(model_path)
|
170 |
+
|
171 |
+
@classmethod
|
172 |
+
def copy_remote_py_files(cls, output_dir):
|
173 |
+
## copy .py and REAMDE for next loading remote code
|
174 |
+
current_file_path = os.path.abspath(__file__)
|
175 |
+
current_folder = os.path.dirname(current_file_path)
|
176 |
+
for file_name in os.listdir(current_folder):
|
177 |
+
if file_name.endswith(".py"):
|
178 |
+
full_file_name = os.path.join(current_folder, file_name)
|
179 |
+
if os.path.isfile(full_file_name):
|
180 |
+
shutil.copy(full_file_name, output_dir)
|
181 |
+
print("[HF remote code] copying", full_file_name, "to", output_dir)
|
182 |
+
|
183 |
+
def save_pretrained(self, output_dir, state_dict=None, safe_serialization=None):
|
184 |
+
if state_dict is None:
|
185 |
+
# other wise fetch from deepspeed
|
186 |
+
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
|
187 |
+
state_dict = self.state_dict()
|
188 |
+
|
189 |
+
if getattr(self, "tokenizer", None):
|
190 |
+
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
|
191 |
+
|
192 |
+
if self.get_llm():
|
193 |
+
print(f"saving llm to {osp.join(output_dir, 'llm')}")
|
194 |
+
self.llm.config._name_or_path = osp.join(output_dir, "llm")
|
195 |
+
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
|
196 |
+
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
|
197 |
+
self.config.llm_cfg = self.llm.config
|
198 |
+
|
199 |
+
if self.get_vision_tower():
|
200 |
+
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
|
201 |
+
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
|
202 |
+
vision_tower_state_dict = OrderedDict(
|
203 |
+
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
|
204 |
+
)
|
205 |
+
self.vision_tower.vision_tower.save_pretrained(
|
206 |
+
os.path.join(output_dir, "vision_tower"),
|
207 |
+
state_dict=vision_tower_state_dict,
|
208 |
+
)
|
209 |
+
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
|
210 |
+
self.config.vision_tower_cfg = self.vision_tower.config
|
211 |
+
if hasattr(self.config.vision_tower_cfg, "auto_map"):
|
212 |
+
if "radio" not in self.get_vision_tower().__class__.__name__.lower():
|
213 |
+
delattr(self.config.vision_tower_cfg, "auto_map")
|
214 |
+
|
215 |
+
if self.get_mm_projector():
|
216 |
+
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
|
217 |
+
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
|
218 |
+
mm_projector_state_dict = OrderedDict(
|
219 |
+
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
|
220 |
+
)
|
221 |
+
self.mm_projector.save_pretrained(
|
222 |
+
os.path.join(output_dir, "mm_projector"),
|
223 |
+
state_dict=mm_projector_state_dict,
|
224 |
+
)
|
225 |
+
self.config.mm_projector_cfg = self.mm_projector.config
|
226 |
+
|
227 |
+
## update and save top-level config
|
228 |
+
self.config._name_or_path = output_dir
|
229 |
+
self.config.architectures = [self.__class__.__name__]
|
230 |
+
#print(self.config)
|
231 |
+
#self.config.save_pretrained(output_dir)
|
232 |
+
|
233 |
+
## copy .py and REAMDE for next loading remote code
|
234 |
+
self.copy_remote_py_files(output_dir)
|
235 |
+
|
236 |
+
|
237 |
|
238 |
@classmethod
|
239 |
def from_pretrained(
|
|
|
286 |
if getattr(self.config, "mm_projector_cfg", None) is None:
|
287 |
self.config.mm_projector_cfg = self.mm_projector.config
|
288 |
|
289 |
+
def get_llm(self):
|
290 |
+
llm = getattr(self, "llm", None)
|
291 |
+
if type(llm) is list:
|
292 |
+
llm = llm[0]
|
293 |
+
return llm
|
294 |
+
|
295 |
+
def get_lm_head(self):
|
296 |
+
lm_head = getattr(self.get_llm(), "lm_head", None)
|
297 |
+
return lm_head
|
298 |
+
|
299 |
def get_vision_tower(self):
|
300 |
vision_tower = getattr(self, "vision_tower", None)
|
301 |
if type(vision_tower) is list:
|
|
|
502 |
if self.training:
|
503 |
# Gather metainfo of media objects from all ranks
|
504 |
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
505 |
+
infos = list(chain(all_gather(info)))
|
506 |
|
507 |
# The entire batch does not contain any media objects of this type.
|
508 |
if not infos:
|
|
|
844 |
if images is not None:
|
845 |
if media is not None:
|
846 |
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
847 |
+
print("The 'images' argument is deprecated. Please use 'media' instead.")
|
848 |
media = {"image": images}
|
849 |
|
850 |
if media_config is None:
|
|
|
939 |
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
940 |
media[name] = [image for image in images]
|
941 |
elif name == "video":
|
942 |
+
if False: #self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
943 |
media[name] = [
|
944 |
process_images(
|
945 |
images,
|
|
|
950 |
).half()
|
951 |
for images in media[name]
|
952 |
]
|
953 |
+
elif False: #self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
954 |
self.config.image_processor = self.vision_tower.image_processor
|
955 |
if type(self.config.s2_scales) is str:
|
956 |
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
|
|
1024 |
if generation_config.eos_token_id is None:
|
1025 |
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
1026 |
return generation_config
|
1027 |
+
|