Upload files with `vila-upload`.
Upload README.md
Upload tokenizer_utils.py
Upload builder.py
Upload mm_utils.py
Upload auto_processor.py
Upload modeling_vila.py
- README.md +3 -3
- auto_processor.py +208 -57
- builder.py +4 -0
- mm_utils.py +5 -2
- modeling_vila.py +49 -13
- tokenizer_utils.py +8 -1
README.md
CHANGED
@@ -67,7 +67,7 @@ model.eval()
 gpt_conv = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
         {"type": "text", "text": "Describe this image."}
     ]
 }]
@@ -106,14 +106,14 @@ model.eval()
 gpt_conv1 = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
         {"type": "text", "text": "Describe this image."}
     ]
 }]
 gpt_conv2 = [{
     "role": "user",
     "content": [
-        {"type": "image", "path": "
+        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example_vqa.jpg"},
         {"type": "text", "text": "Describe this image for me. Provide a detailed description of the image."}
     ]
 }]
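The conversations above use the gpt-style role/content schema that the updated `apply_chat_template` accepts. As a rough end-to-end sketch (the repository id, the `trust_remote_code` loading flags, and the dtype/device choices are assumptions for illustration, not part of this commit), a single conversation could be run like this:

import torch
from transformers import AutoModel, AutoProcessor

model_path = "Efficient-Large-Model/NVILA-Lite-2B-hf-preview"  # assumed repo id, for illustration only
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
model.eval()

gpt_conv = [{
    "role": "user",
    "content": [
        {"type": "image", "path": "https://nvlabs.github.io/VILA/asset/example.jpg"},
        {"type": "text", "text": "Describe this image."},
    ],
}]

# apply_chat_template converts the gpt-style turns into VILA's internal "from"/"value" format;
# the processor then tokenizes a list of such conversations and collects the media tensors.
vila_conv = processor.apply_chat_template(gpt_conv)
inputs = processor([vila_conv])
output_ids = model.generate(
    input_ids=inputs.input_ids,
    media=inputs.media,
    media_config=inputs.media_config,
    attention_mask=inputs.attention_mask,
    max_new_tokens=128,
)
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])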
auto_processor.py
CHANGED
@@ -3,8 +3,11 @@ import os
 import os.path as osp
 import warnings
 from collections import defaultdict
-from
+from io import BytesIO
+from typing import List, Optional, Union
 
+import PIL.Image
+import requests
 import torch
 from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
 from transformers.feature_extraction_utils import BatchFeature
@@ -18,35 +21,73 @@ from .media import Image, Video, extract_media
 from .mm_utils import process_image, process_images
 from .tokenizer_utils import tokenize_conversation
 
+
+def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
+    if pil_image.mode == "RGBA":
+        white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
+        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
+        return white_background
+    else:
+        return pil_image.convert("RGB")
+
+
+def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
+    if "image" in ele:
+        image = ele["image"]
+    else:
+        image = ele["image_url"]
+    image_obj = None
+    if isinstance(image, PIL.Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        response = requests.get(image, stream=True)
+        image_obj = PIL.Image.open(BytesIO(response.content))
+    elif image.startswith("file://"):
+        image_obj = PIL.Image.open(image[7:])
+    elif image.startswith("data:image"):
+        if "base64," in image:
+            _, base64_data = image.split("base64,", 1)
+            data = base64.b64decode(base64_data)
+            image_obj = PIL.Image.open(BytesIO(data))
+    else:
+        image_obj = PIL.Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+    image = to_rgb(image_obj)
+
+    return image
+
+
 def fetch_image_url_or_fpath(url_or_fpath):
     if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
         import tempfile
+
         import requests
-
+
         # Download the image to a temporary file
         temp_dir = tempfile.mkdtemp()
         temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
-
+
         response = requests.get(url_or_fpath, stream=True)
         response.raise_for_status()
-
+
         with open(temp_file, "wb") as f:
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
-
+
         return temp_file
     elif url_or_fpath.startswith("file://"):
         fpath = url_or_fpath.replace("file://", "")
         assert osp.exists(fpath), f"File {fpath} does not exist"
         return fpath
     elif osp.exists(url_or_fpath):
-        assert osp.isfile(url_or_fpath), f"File {url_or_fpath}
+        assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
         return url_or_fpath
     else:
         raise ValueError(f"Unsupported image path: {url_or_fpath}")
-
 
-def __pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="left"):
+
+def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
     # tensor shape is (batch_size, seq_len)
     max_len = max([ids.shape[1] for ids in input_ids_list])
     if target_len is not None:
@@ -66,6 +107,36 @@ def __pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="lef
     return torch.cat(new_input_ids_list, dim=0)
 
 
+def extract_value_from_conv(chat):
+    value = []
+    if isinstance(chat["content"], str):
+        # vila_chat["value"].append(chat["content"])
+        value.append(chat["content"])
+        return value
+
+    # otherwise, it's a list of content
+    for content in chat["content"]:
+        if content["type"] == "image":
+            if "path" in content:
+                # VILA style, can be either filepath or http url
+                value.append(Image(fetch_image_url_or_fpath(content["path"])))
+            elif "image" in content:
+                # Qwen style
+                value.append(Image(fetch_image_url_or_fpath(content["image"])))
+            elif "image_pil" in content:
+                # Qwen style
+                assert isinstance(content["image_pil"], PIL.Image.Image), f"Type of {media_key} must be PIL.Image.Image"
+                value.append(content["image_pil"])
+            else:
+                raise ValueError(f"Type = `image` , but no `path` or `image` in | {content=}, {conversation=}")
+        elif content["type"] == "text":
+            value.append(content["text"])
+        # NOTE(ligeng): video supports are needed here
+        else:
+            raise ValueError(f"Unsupported content type: {content['type']}")
+    return value
+
+
 class VILAProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {
         "text_kwargs": {
@@ -74,8 +145,6 @@ class VILAProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
-
-
 class VILAProcessor(ProcessorMixin):
     # attributes = ["image_processor", "tokenizer"]
     attributes = []
@@ -85,16 +154,83 @@ class VILAProcessor(ProcessorMixin):
     # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
 
     def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
-        # self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
-        # self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
         self.image_token = MEDIA_TOKENS["image"]
         self.video_token = MEDIA_TOKENS["video"]
         self.config = config
         self.image_processor = image_processor
         self.tokenizer = tokenizer
-
+        # self.pad_token_id = tokenizer.pad_token_id
+        self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0]
+        self.eos_token_id = self.tokenizer.eos_token_id
+        # self.pad_token_id = 151643
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
+    @staticmethod
+    def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+        """
+        referernce from qwen_vl_utils
+        """
+        vision_infos = []
+        if isinstance(conversations[0], dict):
+            conversations = [conversations]
+        for conversation in conversations:
+            for message in conversation:
+                if isinstance(message["content"], list):
+                    for ele in message["content"]:
+                        if (
+                            "image" in ele
+                            or "image_url" in ele
+                            or "video" in ele
+                            or ele["type"] in ("image", "image_url", "video")
+                        ):
+                            vision_infos.append(ele)
+        return vision_infos
+
+    @staticmethod
+    def process_vision_info(
+        conversations: list[dict] | list[list[dict]],
+        return_video_kwargs: bool = False,
+    ) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
+        """
+        referernce from qwen_vl_utils
+        """
+        vision_infos = extract_vision_info(conversations)
+        ## Read images or videos
+        image_inputs = []
+        video_inputs = []
+        video_sample_fps_list = []
+        for vision_info in vision_infos:
+            if "image" in vision_info or "image_url" in vision_info:
+                image_inputs.append(fetch_image(vision_info))
+            elif "video" in vision_info:
+                video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
+                video_sample_fps_list.append(video_sample_fps)
+                video_inputs.append(video_input)
+            else:
+                raise ValueError("image, image_url or video should in content.")
+        if len(image_inputs) == 0:
+            image_inputs = None
+        if len(video_inputs) == 0:
+            video_inputs = None
+        if return_video_kwargs:
+            return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+        return image_inputs, video_inputs
+
+    @staticmethod
+    def move_data_to_device(cls, prompt_inputs):
+        def _move_data_to_device(item):
+            # wrap function grpo trainer _prepare_input
+            kwargs = {"device": cls.args.device}
+            if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
+                kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+            return item.to(**kwargs)
+
+        prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
+        prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
+        if "image" in prompt_inputs.media:
+            prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
+        return prompt_inputs
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         if os.path.isdir(pretrained_model_name_or_path):
@@ -115,40 +251,59 @@ class VILAProcessor(ProcessorMixin):
         return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)
 
     def __repr__(self):
-
-
-        )
+        # NOTE(ligeng): hard coded image_processor to avoid serialization error. Dirty fix
+        return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
 
     def __call__(
         self,
-        conversation,
-        images: ImageInput = None,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        videos: VideoInput = None,
+        conversation=None,
         **kwargs: Unpack[VILAProcessorKwargs],
     ) -> BatchFeature:
-
-
+        """
+        The `conv` will be look like
+        [
+            {
+                'from': 'human',
+                'value': [
+                    <transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
+                    'What are the common elements in these pictures?'
+                ]
+            }
+        ]
+        and `conversation` will be a list of such `conv`s
+        """
+        if kwargs.get("text", None) is not None:
+            conversation = kwargs.get("text")
+        assert conversation is not None, "`conversation` or `text` is required"
+        padding_side = kwargs.get("padding_side", "left")
 
-
+        input_ids_list = []
+        attention_mask = []
         media = defaultdict(list)
         media_config = defaultdict(dict)
         for conv in conversation:
-            feat = self.__single_call__(conv,
-
+            feat = self.__single_call__(conv, **kwargs)
+            input_ids_list.append(feat.input_ids)
+            attention_mask.append(feat.attention_mask)
             for name in feat.media:
                 media[name] += feat.media[name]
             for name in feat.media_config:
                 media_config[name].update(feat.media_config[name])
 
+        input_ids = pad_fn(
+            input_ids_list,
+            padding_value=self.pad_token_id,
+            padding_side=padding_side,
+        )
+        # ignore the pad token in the attention mask
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        attention_mask[input_ids == self.pad_token_id] = False
+        # print("[DEBUGAAA]", self.pad_token_id, self.tokenizer.pad_token_id); exit(0)
+
         return BatchFeature(
             data={
-
-                "
-                    input_ids,
-                    padding_value=self.tokenizer.pad_token_id,
-                    padding_side="left",
-                ),
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
                 "media": media,
                 "media_config": media_config,
             }
@@ -195,9 +350,18 @@ class VILAProcessor(ProcessorMixin):
             ]
         else:
             raise ValueError(f"Unsupported media type: {name}")
-
-
-
+
+        inputs = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True, return_ids_only=False)
+        input_ids = inputs.input_ids[0].cuda().unsqueeze(0)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        return BatchFeature(
+            data={
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "media": media,
+                "media_config": media_config,
+            }
+        )
 
     def batch_decode(self, *args, **kwargs):
         """
@@ -235,38 +399,26 @@ class VILAProcessor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
-
-    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
+    def convert_gpt_conv_to_vila_conv(self, conversation):
         vila_conv = []
         for chat in conversation:
             vila_chat = {"from": "", "value": []}
-            if chat["role"]
+            if chat["role"] in ("user", "system"):
                 # user allows to input image and text
-                vila_chat["from"] = "human"
-
-                if content["type"] == "image":
-                    if "path" in content:
-                        # VILA style
-                        vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["path"])))
-                    elif "image" in content:
-                        # Qwen style
-                        vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["image"])))
-                    else:
-                        raise ValueError(f"Unsupported content type `image`: {content}, `image` and `path` are required")
-                elif content["type"] == "text":
-                    vila_chat["value"].append(content["text"])
-                # NOTE(ligeng): video supports are needed here
-                else:
-                    raise ValueError(f"Unsupported content type: {content['type']}")
+                vila_chat["from"] = "human" if chat["role"] == "user" else "system"
+                vila_chat["value"] = extract_value_from_conv(chat)
            elif chat["role"] == "assistant":
                 vila_chat["from"] = "gpt"
-
-
-
+                vila_chat["value"] = extract_value_from_conv(chat)
+            else:
+                raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
             vila_conv.append(vila_chat)
 
         return vila_conv
 
+    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
+        return self.convert_gpt_conv_to_vila_conv(conversation)
+
 
 if __name__ == "__main__":
     # gpt style: user, assistant
@@ -301,7 +453,6 @@ if __name__ == "__main__":
     # print(model.config)
     # print(model.tokenizer)
     # print(res)
-    # exit(0)
 
     processor = VILAProcessor(
         config=model.config,
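The main behavioural change in `__call__` is batching: each conversation is tokenized separately, the per-conversation `input_ids` are left-padded to a common length with the Qwen `<|endoftext|>` id, and an attention mask is built that zeroes out the pad positions. A minimal, self-contained sketch of that padding logic (the `left_pad` helper below is an illustrative stand-in for `pad_fn`, not the repository function itself):

from typing import List
import torch

def left_pad(input_ids_list: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    # each tensor has shape (1, seq_len); pad on the left to a common length
    max_len = max(ids.shape[1] for ids in input_ids_list)
    padded = []
    for ids in input_ids_list:
        pad = torch.full((1, max_len - ids.shape[1]), padding_value, dtype=ids.dtype)
        padded.append(torch.cat([pad, ids], dim=1))
    return torch.cat(padded, dim=0)

pad_token_id = 151643  # the Qwen "<|endoftext|>" id mentioned in the diff's comment
batch = [torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9]])]
input_ids = left_pad(batch, pad_token_id)
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
attention_mask[input_ids == pad_token_id] = False
print(input_ids)       # tensor([[     5,      6,      7], [151643,      8,      9]])
print(attention_mask)  # tensor([[ True,  True,  True], [False,  True,  True]])

Left padding keeps the last prompt token adjacent to the first generated token, which is what batched decoder-only generation expects.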
builder.py
CHANGED
@@ -33,6 +33,7 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizer,
 )
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 # from .conversation import *
 from .conversation import SeparatorStyle, default_conversation
@@ -202,6 +203,9 @@ def build_llm_and_tokenizer(
             fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
         )
     else:
+        if is_deepspeed_zero3_enabled():
+            # NOTE: found by wei, need to pop out device_map when using zero3
+            kwargs.pop("device_map")
         llm = AutoModelForCausalLM.from_pretrained(
             model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
         )
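The new guard drops `device_map` from the loading kwargs when DeepSpeed ZeRO-3 is active, since ZeRO-3 shards and places parameters itself and a user-supplied device map conflicts with that. A hedged sketch of the same guard in isolation (`load_llm` is an illustrative wrapper, not a function from this repository):

from transformers import AutoModelForCausalLM
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

def load_llm(model_name_or_path: str, **kwargs):
    if is_deepspeed_zero3_enabled():
        # ZeRO-3 handles parameter placement; a device_map would conflict with it.
        kwargs.pop("device_map", None)  # the diff pops unconditionally; a default avoids a KeyError
    return AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs)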
mm_utils.py
CHANGED
@@ -521,8 +521,11 @@ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False,
     return new_images
 
 
-def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
-
+def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
+    if return_ids:
+        return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
+    else:
+        return tokenizer(prompt, return_tensors=return_tensors)
 
 
 def is_gemma_tokenizer(tokenizer):
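The `return_ids` switch lets callers choose between the bare id tensor (the previous behaviour) and the full tokenizer output, which also carries the attention mask that the processor's batched path needs. Roughly, with any Hugging Face tokenizer (the `gpt2` checkpoint below is only a stand-in for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

ids_only = tokenizer("describe the image", return_tensors="pt").input_ids[0]  # return_ids=True path: 1-D id tensor
full = tokenizer("describe the image", return_tensors="pt")                   # return_ids=False path: full BatchEncoding
print(ids_only.shape)        # (seq_len,)
print(full.input_ids.shape)  # (1, seq_len)
print(full.attention_mask)   # only available on the full encoding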
modeling_vila.py
CHANGED
@@ -428,6 +428,12 @@ class VILAPretrainedModel(PreTrainedModel):
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
 
+        self.pad_token_list = (
+            self.tokenizer.pad_token_id,
+            self.tokenizer.eos_token_id,
+            self.tokenizer.tokenize("<|endoftext|>")[0],  # for qwen
+        )
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -650,11 +656,9 @@ class VILAForCasualLM(VILAPretrainedModel):
                 name = media_tokens[input_ids[k][pos].item()]
                 input = media_embeds[name].popleft()
                 label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
-
-            elif input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
+            elif input_ids[k][pos].item() in self.pad_token_list:
                 end = pos + 1
                 pos = end
-                # print(f"[skip PAD/EOS] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
                 continue
             else:
                 end = pos
@@ -662,7 +666,6 @@ class VILAForCasualLM(VILAPretrainedModel):
                     end += 1
                 input = text_embeds[k][pos:end]
                 label = labels[k][pos:end]
-                # print(f"[text] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
 
             inputs_mk.append(input)
             labels_mk.append(label)
@@ -1018,6 +1021,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         media: Optional[Dict[str, List[torch.Tensor]]] = None,
         images: Optional[torch.FloatTensor] = None,
         media_config: Optional[List] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -1074,21 +1078,56 @@ class VILAForCasualLM(VILAPretrainedModel):
 
         return outputs
 
-    @torch.inference_mode()
+    # @torch.inference_mode()
     def generate(
         self,
         input_ids: Optional[torch.FloatTensor] = None,
         media: Optional[Dict[str, List[torch.Tensor]]] = None,
         media_config: Dict[str, Dict[str, Any]] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        return_output_ids_only: bool = False,
         **generation_kwargs,
-    ):
+    ) -> torch.LongTensor:
+        model_training_status = False
         if self.training:
             warnings.warn(
-                "Model is in training mode, using default padding strategy to right. This is not recommended for generation."
+                "Model is in training mode, using default padding strategy to right. This is not recommended for generation. We implicitly set the model to evaluation mode and restore the model training status after generation."
             )
+            self.eval()
+            model_training_status = True
+        """
+        input_tokens: <image> describe the image
+        media: [Tensor(1, 3, 384, 384), ]
+        ----------->
+        input_tokens: 36000 001 002 003 004
+        input_emds: <media emd> 001 002 003 004
+        """
+        # TODO: there is still a padding left vs right issue unsovled here.
+        # print("prev args:",input_ids.shape, media, media_config, None, attention_mask)
         inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
-
+        # print("inputs_embeds", inputs_embeds.shape, inputs_embeds.mean(), inputs_embeds.std())
+        # print("attention_mask", attention_mask.shape, attention_mask)
+        output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
+        # print("output_ids", self.tokenizer.batch_decode(output_ids))
+        # input("wait for debug")
+        if return_output_ids_only:
+            return_value = output_ids
+        else:
+            # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
+            # print(f"[DEBUG REMOTE] input_ids: {input_ids.shape}, output_ids: {output_ids.shape} attention_mask: {attention_mask.shape} {generation_kwargs=}"); exit(0)
+            generation_config = generation_kwargs.get("generation_config", None)
+            if generation_config is not None:
+                num_generations = generation_config.num_return_sequences
+                repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
+                return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
+            else:
+                return_value = torch.cat([input_ids, output_ids], dim=-1)
+
+        if model_training_status:
+            # restore the model training status
+            self.train()
+
+        return return_value
 
     @torch.inference_mode()
     def generate_content(
@@ -1101,10 +1140,7 @@ class VILAForCasualLM(VILAPretrainedModel):
         conversation = [{"from": "human", "value": prompt}]
 
         # Convert response format to logits processor
-
-            xgr_logits_processor = self.get_xgr_logits_processor(response_format)
-        else:
-            xgr_logits_processor = None
+        xgr_logits_processor = None
 
         # Extract media from the conversation
 
@@ -1173,7 +1209,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                 raise ValueError(f"Unsupported media type: {name}")
 
         # Tokenize the conversation
-        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).
+        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
 
         # Set up the generation config
         generation_config = generation_config or self.default_generation_config
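`generate` now returns the prompt ids concatenated with the newly generated ids (to match community VLMs such as Qwen), unless `return_output_ids_only=True` is passed; when a `GenerationConfig` with `num_return_sequences` is supplied, the prompt is repeated to match. A hedged sketch of consuming that contract, assuming `model` and `inputs` were produced by the VILA remote-code processor as in the earlier usage example:

def run_generation(model, inputs, max_new_tokens=64):
    # Default contract: prompt ids followed by new ids along the sequence dimension.
    full_ids = model.generate(
        input_ids=inputs.input_ids,
        media=inputs.media,
        media_config=inputs.media_config,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
    )
    new_ids = full_ids[:, inputs.input_ids.shape[1]:]  # strip the echoed prompt

    # Alternative: ask for only the generated ids.
    only_new_ids = model.generate(
        input_ids=inputs.input_ids,
        media=inputs.media,
        media_config=inputs.media_config,
        attention_mask=inputs.attention_mask,
        return_output_ids_only=True,
        max_new_tokens=max_new_tokens,
    )
    return new_ids, only_new_ids

Since the commit also comments out the `@torch.inference_mode()` decorator on `generate`, inference-only callers may want to wrap these calls in `torch.inference_mode()` themselves.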
tokenizer_utils.py
CHANGED
@@ -68,13 +68,16 @@ def tokenize_conversation_legacy(
     return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
 
 
+# NOTE(ligeng): add a return typing to help code analyze
 def tokenize_conversation(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
     add_generation_prompt: bool = False,
     overrides: Optional[Dict[str, str]] = None,
     no_system_prompt: bool = False,
+    return_ids_only=True,
 ) -> torch.Tensor:
+    # print("messages", messages); input()
     # Normalize the conversation before tokenization
     for message in messages:
         message["value"] = message["value"].strip()
@@ -95,6 +98,10 @@ def tokenize_conversation(
             message["role"] = "user"
         elif m["from"] == "gpt":
             message["role"] = "assistant"
+        elif m["from"] == "system":
+            message["role"] = "system"
+            if no_system_prompt:
+                raise ValueError("System prompt is not allowed when no_system_prompt is True.")
         else:
             raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
 
@@ -111,7 +118,7 @@ def tokenize_conversation(
         add_generation_prompt=add_generation_prompt,
         tokenize=False,
     )
-    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
+    return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)
 
 
 def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
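With the new branch, system turns in the internal "from"/"value" format are mapped to the chat template's system role instead of raising, and are rejected when `no_system_prompt=True`. A sketch of the role remapping that `tokenize_conversation` performs before calling the chat template (the message contents and the stand-in tokenizer below are illustrative assumptions):

from transformers import AutoTokenizer

messages = [
    {"from": "system", "value": "You are a helpful visual assistant."},
    {"from": "human", "value": "<image> Describe this image."},
    {"from": "gpt", "value": "A red barn in a snowy field."},
]

# The same remapping tokenize_conversation applies before apply_chat_template.
role_map = {"system": "system", "human": "user", "gpt": "assistant"}
chat = [{"role": role_map[m["from"]], "content": m["value"].strip()} for m in messages]

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # stand-in tokenizer with a chat template
text = tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False)
print(text)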