|
import copy |
|
import json |
|
import logging |
|
import math |
|
import os |
|
import os.path |
|
import os.path as osp |
|
import shutil |
|
import warnings |
|
from abc import ABC |
|
from collections import OrderedDict, defaultdict, deque |
|
from copy import deepcopy |
|
from itertools import chain |
|
from threading import Thread |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision |
|
from einops import rearrange |
|
from PIL import Image |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModel, |
|
AutoProcessor, |
|
AutoTokenizer, |
|
GenerationConfig, |
|
LogitsProcessor, |
|
PretrainedConfig, |
|
PreTrainedModel, |
|
Qwen2Config, |
|
Qwen2ForCausalLM, |
|
Qwen2PreTrainedModel, |
|
TextIteratorStreamer, |
|
) |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from transformers.modeling_utils import ContextManagers, no_init_weights |
|
|
|
from .auto_processor import VILAProcessor |
|
from .base_projector import MultimodalProjector, MultimodalProjectorConfig |
|
from .builder import build_llm_and_tokenizer |
|
from .configuration_vila import VILAConfig |
|
from .constants import * |
|
from .conversation import SeparatorStyle, default_conversation |
|
from .distributed import all_gather as vila_all_gather |
|
from .loss import soft_cross_entropy |
|
from .media import extract_media |
|
from .media_encoder import BasicImageEncoder, BasicVideoEncoder |
|
from .mm_utils import process_image, process_images |
|
from .model_utils_packing import set_seqlens_in_batch |
|
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2 |
|
from .tokenizer_utils import tokenize_conversation |
|
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template |
|
|
|
|
|
|
|
|
|
python_input = input |
|
|
|
|
|
def get_pg_manager():
    """Return the sequence-parallel process-group manager; None in this standalone remote-code path."""
    return None
|
|
|
|
|
def get_model_weights_dtype(model: nn.Module):
    """Not implemented; returns None."""
    pass
|
|
|
|
|
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
if model_type_or_path is None: |
|
return None |
|
|
|
if config.resume_path: |
|
assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!" |
|
return MultimodalProjector.from_pretrained(model_type_or_path, config) |
|
|
|
else: |
|
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path) |
|
mm_projector = MultimodalProjector(mm_projector_cfg, config) |
|
return mm_projector |
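
# Usage sketch for `build_mm_projector` (illustrative only; the projector type
# string below is an assumption, not guaranteed to match every checkpoint):
#
#   mm_projector = build_mm_projector("mlp_downsample", config)      # fresh projector from a type id
#   mm_projector = build_mm_projector("/ckpt/mm_projector", config)  # with config.resume_path set, loads from disk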
|
|
|
|
|
def check_dot_in_model_path(model_path: str): |
|
"""Check if the model path contains dot, which will affect the remote code loading.""" |
|
if osp.isdir(model_path): |
|
if "." in osp.abspath(model_path): |
|
return True |
|
else: |
|
if "." in model_path: |
|
return True |
|
return False |
|
|
|
|
|
def get_vila_version(model_path: str) -> Optional[str]:
|
VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"] |
|
for version in VERSIONS: |
|
if version in model_path.lower(): |
|
return version |
|
return None |
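
# Sketch: the lookup is a case-insensitive substring match, e.g. a path such as
# "checkpoints/NVILA-8B" (hypothetical) maps to "nvila"; unknown paths yield None.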
|
|
|
|
|
def generate_jinja_template(conv_mode: str) -> str: |
|
if conv_mode == "vicuna_v1": |
|
return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %} |
|
{% set roles = ["user", "assistant"] %} |
|
{% set sep = " " %} |
|
|
|
{{ system_prompt }} |
|
|
|
{% for message in messages %} |
|
{% if message['role'] == roles[0] %} |
|
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }} |
|
{% else %} |
|
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }} |
|
{% endif %} |
|
{% endfor %} |
|
{% if messages[-1]['role'] == 'user' %} |
|
{{ "ASSISTANT:" }} |
|
{% endif %} |
|
""" |
|
elif conv_mode == "llama_3": |
|
return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %} |
|
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%} |
|
{% set sep = "<|eot_id|>" %} |
|
|
|
{{ system_prompt }} |
|
{% for message in messages %} |
|
{% if message['role'] == 'user' %} |
|
{{ roles[0] }}{{ message['content'] }}{{ sep }} |
|
{% else %} |
|
{{ roles[1] }}{{ message['content'] }}{{ sep }} |
|
{% endif %} |
|
{% endfor %} |
|
{% if messages[-1]['role'] == 'user' %} |
|
{{ roles[1] }} |
|
{% endif %} |
|
""" |
|
elif conv_mode == "hermes_2": |
|
return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %} |
|
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %} |
|
{% set sep = "<|im_end|>" %} |
|
|
|
{{ system_prompt }}{{ sep }} |
|
|
|
{% for message in messages %} |
|
{% if message['role'] == 'user' %} |
|
{{ roles[0] }}{{ message['content'] }}{{ sep }} |
|
{% else %} |
|
{{ roles[1] }}{{ message['content'] }}{{ sep }} |
|
{% endif %} |
|
{% endfor %}""" |
|
else: |
|
raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.") |
|
|
|
|
|
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
|
|
if model_name_or_path is None: |
|
return None |
|
|
|
vision_tower_arch = None |
|
if config.resume_path and "radio" not in model_name_or_path: |
|
assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!" |
|
vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) |
|
vision_tower_arch = vision_tower_cfg.architectures[0].lower() |
|
vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path |
|
|
|
use_s2 = getattr(config, "s2", False) |
|
use_dynamic_s2 = getattr(config, "dynamic_s2", False) |
|
|
|
if "siglip" in vision_tower_name: |
|
if use_dynamic_s2: |
|
vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config) |
|
elif use_s2: |
|
vision_tower = SiglipVisionTowerS2(model_name_or_path, config) |
|
else: |
|
vision_tower = SiglipVisionTower(model_name_or_path, config) |
|
else: |
|
raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}") |
|
|
|
config.mm_hidden_size = ( |
|
vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size |
|
) |
|
return vision_tower |
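
# Usage sketch (the SigLIP checkpoint id is an assumption for illustration):
#
#   vision_tower = build_vision_tower("google/siglip-so400m-patch14-384", config)
#
# Note that `config.mm_hidden_size` is set as a side effect; with S2 / dynamic S2
# enabled it reflects the concatenated multi-scale feature width rather than the
# backbone's own hidden size.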
|
|
|
|
|
class VILAPretrainedModel(PreTrainedModel): |
|
config_class = VILAConfig |
|
main_input_name = "input_embeds" |
|
supports_gradient_checkpointing = True |
|
_supports_flash_attn_2 = True |
|
|
|
def __init__(self, config: VILAConfig, *args, **kwargs): |
|
super().__init__(config) |
|
self.config = config |
|
cfgs = get_model_config(config) |
|
if len(cfgs) == 3: |
|
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs |
|
else: |
|
raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.") |
|
|
|
|
|
device_map = kwargs.get("device_map", "cpu") |
|
self.mm_projector = build_mm_projector(mm_projector_cfg, config) |
|
self.vision_tower = build_vision_tower(vision_tower_cfg, config) |
|
if "auto" in device_map or "cuda" in device_map: |
|
self.mm_projector = self.mm_projector.cuda() |
|
self.vision_tower = self.vision_tower.cuda() |
|
|
|
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map) |
|
|
|
|
|
self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)} |
|
|
|
self.post_config() |
|
self.is_loaded = True |
|
|
|
assert ( |
|
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None |
|
), "At least one of the components must be instantiated." |
|
|
|
@classmethod |
|
def convert_vila_dev_ckpt_to_remote( |
|
        cls,
        model_path: str,
        output_dir: Optional[str] = None,
|
vila_version: str | None = None, |
|
conv_mode: str | None = None, |
|
copy: bool = False, |
|
copy_weights: bool = True, |
|
copy_code: bool = True, |
|
*model_args, |
|
**kwargs, |
|
): |
|
|
|
assert model_path != output_dir, "model_path and output_dir cannot be the same" |
|
        if not os.path.isdir(model_path):
            from huggingface_hub import snapshot_download

            model_path = snapshot_download(model_path)
            print("downloaded HF model to", model_path)
|
|
|
if check_dot_in_model_path(model_path) and output_dir is None: |
|
raise ValueError( |
|
f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue." |
|
) |
|
if output_dir is not None and "." in output_dir: |
|
raise ValueError( |
|
f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots." |
|
) |
|
|
|
if copy: |
|
print("copy is set to True, copying weights and code to output_dir") |
|
copy_weights = copy_code = True |
|
|
|
        cls.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
        cls.copy_remote_py_files(output_dir, copy=copy_code)
|
|
|
if vila_version is None: |
|
vila_version = get_vila_version(output_dir) |
|
|
|
cfg_path = os.path.join(output_dir, "config.json") |
|
        with open(cfg_path) as f:
            config = json.load(f)
|
config["version"] = "2.0" |
|
config["architectures"] = ["VILAForCasualLM"] |
|
config["auto_map"] = { |
|
"AutoProcessor": "auto_processor.VILAProcessor", |
|
"AutoConfig": "modeling_vila.VILAConfig", |
|
"AutoModel": "modeling_vila.VILAForCasualLM", |
|
"AutoModelForCausalLM": "modeling_vila.VILAForCasualLM", |
|
} |
|
|
|
config["model_type"] = "vila" |
|
if vila_version in ["vila1.5", "vila-m3"]: |
|
if conv_mode is None: |
|
raise ValueError(f"Please specify the conversation mode for {output_dir}.") |
|
config["chat_template"] = conv_mode |
|
jinja_template = generate_jinja_template(conv_mode) |
|
jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja") |
|
with open(jinja_path, "w") as f: |
|
f.write(jinja_template) |
|
        with open(cfg_path, "w") as f:
            json.dump(config, f, indent=2)
|
|
|
|
|
config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True) |
|
tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config) |
|
tokenizer.save_pretrained(osp.join(output_dir, "llm")) |
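
    # Usage sketch for the conversion above (paths are hypothetical):
    #
    #   VILAPretrainedModel.convert_vila_dev_ckpt_to_remote(
    #       model_path="/ckpts/vila-dev",
    #       output_dir="/ckpts/vila-hf",
    #       copy=True,  # copy weights and code instead of symlinking
    #   )
    #
    # The output directory then loads with `trust_remote_code=True` through the
    # `auto_map` entries written into config.json.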
|
|
|
|
|
@classmethod |
|
def copy_or_symlink_directory(cls, model_path, output_dir, copy=True): |
|
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
for item in os.listdir(model_path): |
|
src_path = os.path.join(model_path, item) |
|
dst_path = os.path.join(output_dir, item) |
|
|
|
|
|
if os.path.exists(dst_path): |
|
if os.path.islink(dst_path): |
|
os.unlink(dst_path) |
|
elif os.path.isdir(dst_path): |
|
shutil.rmtree(dst_path) |
|
else: |
|
os.remove(dst_path) |
|
|
|
|
|
if copy: |
|
if os.path.isdir(src_path): |
|
shutil.copytree(src_path, dst_path) |
|
else: |
|
shutil.copy2(src_path, dst_path) |
|
print(f"Copied {src_path} to {dst_path}") |
|
else: |
|
os.symlink(src_path, dst_path) |
|
print(f"Created symlink from {src_path} to {dst_path}") |
|
|
|
@classmethod |
|
def copy_remote_py_files(cls, output_dir, copy=True): |
|
|
|
current_file_path = os.path.abspath(__file__) |
|
current_folder = os.path.dirname(current_file_path) |
|
for file_name in os.listdir(current_folder): |
|
if file_name == "INSTRUCTIONS.md": |
|
src_fname = os.path.join(current_folder, file_name) |
|
dst_fname = os.path.join(output_dir, "README.md") |
|
if os.path.exists(dst_fname): |
|
                    with open(dst_fname) as f:
                        old_readme = f.read()
                else:
                    old_readme = ""
                with open(src_fname) as src, open(dst_fname, "w") as dst:
                    dst.write(src.read())
                    dst.write(old_readme)
                print("[HF remote code] README", src_fname, "to", dst_fname)
|
if file_name.endswith(".py") or file_name.endswith(".jinja"): |
|
full_file_name = os.path.join(current_folder, file_name) |
|
if os.path.isfile(full_file_name): |
|
if copy: |
|
shutil.copy(full_file_name, output_dir) |
|
print("[HF remote code] copying", full_file_name, "to", output_dir) |
|
else: |
|
|
|
if os.path.exists(os.path.join(output_dir, file_name)): |
|
os.remove(os.path.join(output_dir, file_name)) |
|
os.symlink(full_file_name, os.path.join(output_dir, file_name)) |
|
print("[HF remote code] linking", full_file_name, "to", output_dir) |
|
|
|
def save_pretrained(self, output_dir, state_dict=None, **kwargs): |
|
if state_dict is None: |
|
|
|
|
|
state_dict = self.state_dict() |
|
|
|
if getattr(self, "tokenizer", None): |
|
self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) |
|
|
|
if self.get_llm(): |
|
print(f"saving llm to {osp.join(output_dir, 'llm')}") |
|
self.llm.config._name_or_path = osp.join(output_dir, "llm") |
|
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}) |
|
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict) |
|
self.config.llm_cfg = self.llm.config |
|
|
|
if self.get_vision_tower(): |
|
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}") |
|
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower") |
|
vision_tower_state_dict = OrderedDict( |
|
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k} |
|
) |
|
self.vision_tower.vision_tower.save_pretrained( |
|
os.path.join(output_dir, "vision_tower"), |
|
state_dict=vision_tower_state_dict, |
|
) |
|
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower")) |
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
if hasattr(self.config.vision_tower_cfg, "auto_map"): |
|
if "radio" not in self.get_vision_tower().__class__.__name__.lower(): |
|
delattr(self.config.vision_tower_cfg, "auto_map") |
|
|
|
if self.get_mm_projector(): |
|
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}") |
|
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector") |
|
mm_projector_state_dict = OrderedDict( |
|
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k} |
|
) |
|
self.mm_projector.save_pretrained( |
|
os.path.join(output_dir, "mm_projector"), |
|
state_dict=mm_projector_state_dict, |
|
) |
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
|
|
|
self.config._name_or_path = output_dir |
|
self.config.architectures = [self.__class__.__name__] |
|
self.config.save_pretrained(output_dir) |
|
|
|
|
|
self.copy_remote_py_files(output_dir) |
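
    # Resulting layout sketch after `save_pretrained(output_dir)` (names follow the
    # sub-folders written above):
    #
    #   output_dir/
    #       config.json      # top-level VILAConfig with updated sub-configs
    #       llm/             # language model weights + tokenizer
    #       vision_tower/    # vision backbone weights + image processor
    #       mm_projector/    # projector weights
    #       *.py, *.jinja    # remote code copied by copy_remote_py_files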
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Optional[str] = None, |
|
*model_args, |
|
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, |
|
cache_dir: Optional[Union[str, os.PathLike]] = None, |
|
ignore_mismatched_sizes: bool = False, |
|
force_download: bool = False, |
|
local_files_only: bool = False, |
|
token: Optional[Union[str, bool]] = None, |
|
revision: str = "main", |
|
use_safetensors: Optional[bool] = None, |
|
weights_only: bool = True, |
|
**kwargs, |
|
): |
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
return cls._from_config(config, **kwargs) |
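
    # Loading sketch: this override only resolves the config and then builds the
    # model through `_from_config`; the sub-modules are materialized from the paths
    # recorded inside that config. A typical entry point (the checkpoint path is a
    # placeholder) would be:
    #
    #   model = AutoModel.from_pretrained(
    #       "path/to/converted-vila-checkpoint", trust_remote_code=True, device_map="auto"
    #   )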
|
|
|
def init_llm(self, llm_config, config, *args, **kwargs): |
|
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs) |
|
|
|
|
|
|
|
NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys()) |
|
|
|
self.pad_token_list = ( |
|
self.tokenizer.pad_token_id, |
|
self.tokenizer.eos_token_id, |
|
self.tokenizer.tokenize("<|endoftext|>")[0], |
|
) |
|
|
|
|
|
self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS |
|
|
|
|
|
self.grammar_compiler = None |
|
self.llm.resize_token_embeddings(len(self.tokenizer)) |
|
return self.llm, self.tokenizer |
|
|
|
def post_config(self): |
|
|
|
|
|
self.llm = self.llm.to(torch.float16) |
|
self.mm_projector = self.mm_projector.to(torch.float16) |
|
self.vision_tower = self.vision_tower.to(torch.float16) |
|
|
|
self.training = self.llm.training |
|
|
|
if getattr(self.config, "llm_cfg", None) is None: |
|
self.config.llm_cfg = self.llm.config |
|
if getattr(self.config, "vision_tower_cfg", None) is None: |
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
if getattr(self.config, "mm_projector_cfg", None) is None: |
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
|
def get_llm(self): |
|
llm = getattr(self, "llm", None) |
|
if type(llm) is list: |
|
llm = llm[0] |
|
return llm |
|
|
|
def get_lm_head(self): |
|
lm_head = getattr(self.get_llm(), "lm_head", None) |
|
return lm_head |
|
|
|
def get_vision_tower(self): |
|
vision_tower = getattr(self, "vision_tower", None) |
|
if type(vision_tower) is list: |
|
vision_tower = vision_tower[0] |
|
return vision_tower |
|
|
|
def get_mm_projector(self): |
|
mm_projector = getattr(self, "mm_projector", None) |
|
if type(mm_projector) is list: |
|
mm_projector = mm_projector[0] |
|
return mm_projector |
|
|
|
def freezed_module_patch(self): |
|
""" |
|
        Hugging Face calls model.train() at each training step. To keep the expected behavior of modules like dropout and batchnorm, we call model.eval() on the frozen modules.
|
""" |
|
if self.training: |
|
if self.get_llm() and not getattr(self.config, "tune_language_model", False): |
|
pass |
|
|
|
if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False): |
|
self.get_vision_tower().eval() |
|
if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False): |
|
self.get_mm_projector().eval() |
|
|
|
|
|
class VILAForCasualLM(VILAPretrainedModel): |
|
def __init__(self, config: VILAConfig, *args, **kwargs): |
|
super().__init__(config, *args, **kwargs) |
|
|
|
def merge_features_for_dynamic_s2(self, image_features, block_sizes): |
|
scales = self.get_vision_tower().scales |
|
resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx |
|
|
|
image_features_each_image = [] |
|
new_block_sizes = [] |
|
block_cnt = 0 |
|
for block_size_each_image in block_sizes: |
|
if block_size_each_image is None: |
|
cur_features = image_features[block_cnt : block_cnt + 1] |
|
cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5)) |
|
cur_features = cur_features.repeat(1, len(scales), 1, 1) |
|
image_features_each_image.append(cur_features) |
|
new_block_sizes.append((1, 1)) |
|
block_cnt += 1 |
|
else: |
|
cur_features_each_scale = [] |
|
for scale in scales[:-1]: |
|
num_blocks_this_scale = (scale // scales[0]) ** 2 |
|
cur_features_each_scale.append( |
|
self.merge_chessboard( |
|
image_features[block_cnt : block_cnt + num_blocks_this_scale], |
|
num_split_h=scale // scales[0], |
|
num_split_w=scale // scales[0], |
|
) |
|
) |
|
block_cnt += num_blocks_this_scale |
|
num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1] |
|
cur_features_each_scale.append( |
|
self.merge_chessboard( |
|
image_features[block_cnt : block_cnt + num_blocks_last_scale], |
|
num_split_h=block_size_each_image[0], |
|
num_split_w=block_size_each_image[1], |
|
) |
|
) |
|
block_cnt += num_blocks_last_scale |
|
|
|
|
|
output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:] |
|
cur_features = torch.cat( |
|
[ |
|
F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to( |
|
cur_features_each_scale[i].dtype |
|
) |
|
for i in range(len(cur_features_each_scale)) |
|
], |
|
dim=1, |
|
) |
|
|
|
|
|
image_features_each_image.append(cur_features) |
|
|
|
if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1: |
|
new_block_sizes.append(block_size_each_image) |
|
else: |
|
new_block_sizes.append( |
|
( |
|
scales[resize_output_to_scale_idx] // scales[0], |
|
scales[resize_output_to_scale_idx] // scales[0], |
|
) |
|
) |
|
|
|
assert block_cnt == len(image_features) |
|
|
|
return image_features_each_image, new_block_sizes |
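
    # Block-accounting sketch, assuming example S2 scales [448, 896, 1344] (so
    # scales[0] == 448):
    #   * block_sizes[i] is None   -> exactly 1 feature block is consumed for image i.
    #   * block_sizes[i] == (3, 2) -> (448 // 448) ** 2 + (896 // 448) ** 2 blocks for
    #     the fixed lower scales plus 3 * 2 blocks for the largest scale,
    #     i.e. 1 + 4 + 6 = 11 blocks.
    # The final assert verifies that exactly len(image_features) blocks were consumed.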
|
|
|
    def encode_images(self, images, block_sizes: Optional[List[Optional[Tuple[int, int]]]] = None):
|
if block_sizes is None: |
|
block_sizes = [None] * len(images) |
|
if getattr(self.config, "dynamic_s2", False): |
|
image_features = self.get_vision_tower()(images) |
|
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes) |
|
|
|
image_features = [ |
|
self.split_chessboard(x, block_size[0], block_size[1]) |
|
for x, block_size in zip(image_features, new_block_sizes) |
|
] |
|
image_features = torch.cat( |
|
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0 |
|
) |
|
image_features = self.get_mm_projector()(image_features) |
|
image_features = list( |
|
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0) |
|
) |
|
image_features = [ |
|
self.merge_chessboard(x, block_size[0], block_size[1]) |
|
for x, block_size in zip(image_features, new_block_sizes) |
|
] |
|
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] |
|
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]): |
|
image_features = torch.stack(image_features, dim=0) |
|
else: |
|
image_features = self.get_vision_tower()(images) |
|
image_features = self.get_mm_projector()(image_features) |
|
return image_features |
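
    # Usage sketch (shapes are illustrative): with `dynamic_s2` disabled the result
    # is simply mm_projector(vision_tower(images)). With `dynamic_s2` enabled,
    # per-image `block_sizes` describe the tiling of the largest scale, e.g.
    #
    #   features = self.encode_images(images, block_sizes=[(3, 2), None])
    #
    # and the return value is a list of per-image token sequences (or a stacked
    # tensor when all images yield the same number of tokens).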
|
|
|
def train(self, mode: bool = True): |
|
if mode: |
|
self.tokenizer.padding_side = "right" |
|
else: |
|
self.tokenizer.padding_side = "left" |
|
super().train(mode) |
|
return self |
|
|
|
def _embed( |
|
self, |
|
input_ids: torch.Tensor, |
|
media: Dict[str, List[torch.Tensor]], |
|
media_config: Dict[str, Dict[str, Any]], |
|
labels: Optional[torch.Tensor], |
|
attention_mask: Optional[torch.Tensor], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
media = copy.deepcopy(media) |
|
media_config = copy.deepcopy(media_config) |
|
|
|
labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX) |
|
attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool) |
|
|
|
PROCESS_GROUP_MANAGER = get_pg_manager() |
|
if PROCESS_GROUP_MANAGER is not None: |
|
for name in media: |
|
self.encoders[name].end_tokens = None |
|
|
|
|
|
text_embeds = self.llm.model.embed_tokens(input_ids) |
|
if media is not None: |
|
media_embeds = self.__embed_media_tokens(media, media_config) |
|
else: |
|
|
|
media_embeds = {} |
|
|
|
|
|
while media_embeds.get("dummy"): |
|
dummy_embed = media_embeds["dummy"].popleft() |
|
text_embeds += torch.sum(dummy_embed) * 0 |
|
|
|
|
|
batch_size = labels.shape[0] |
|
text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)] |
|
labels = [labels[k][attention_mask[k]] for k in range(batch_size)] |
|
|
|
|
|
media_tokens = {} |
|
for name, token_id in self.tokenizer.media_token_ids.items(): |
|
media_tokens[token_id] = name |
|
|
|
|
|
inputs_m, labels_m = [], [] |
|
for k in range(batch_size): |
|
inputs_mk, labels_mk = [], [] |
|
pos = 0 |
|
while pos < len(labels[k]): |
|
if input_ids[k][pos].item() in media_tokens: |
|
end = pos + 1 |
|
name = media_tokens[input_ids[k][pos].item()] |
|
input = media_embeds[name].popleft() |
|
label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype) |
|
elif input_ids[k][pos].item() in self.pad_token_list: |
|
end = pos + 1 |
|
pos = end |
|
continue |
|
else: |
|
end = pos |
|
while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens: |
|
end += 1 |
|
input = text_embeds[k][pos:end] |
|
label = labels[k][pos:end] |
|
|
|
inputs_mk.append(input) |
|
labels_mk.append(label) |
|
pos = end |
|
inputs_m.append(torch.cat(inputs_mk, dim=0)) |
|
labels_m.append(torch.cat(labels_mk, dim=0)) |
|
inputs, labels = inputs_m, labels_m |
|
|
|
|
|
for name in media_embeds: |
|
if media_embeds[name]: |
|
raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.") |
|
|
|
|
|
inputs, labels = self.__truncate_sequence(inputs, labels) |
|
|
|
|
|
return self.__batchify_sequence(inputs, labels) |
|
|
|
def __embed_media_tokens( |
|
self, |
|
media: Dict[str, List[torch.Tensor]], |
|
media_config: Dict[str, Dict[str, Any]], |
|
) -> Dict[str, List[torch.Tensor]]: |
|
embeds = defaultdict(deque) |
|
for name in media: |
|
if self.training: |
|
|
|
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])] |
|
infos = list(chain(vila_all_gather(info))) |
|
|
|
|
|
if not infos: |
|
continue |
|
|
|
|
|
if media.get(name) is None or len(media[name]) == 0: |
|
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device) |
|
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name])) |
|
continue |
|
embeds[name] = deque(self.encoders[name](media[name], media_config[name])) |
|
return embeds |
|
|
|
def __truncate_sequence( |
|
self, inputs: List[torch.Tensor], labels: List[torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs): |
|
warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).") |
|
inputs = [input[: self.tokenizer.model_max_length] for input in inputs] |
|
labels = [label[: self.tokenizer.model_max_length] for label in labels] |
|
return inputs, labels |
|
|
|
def __batchify_sequence( |
|
self, inputs: List[torch.Tensor], labels: List[torch.Tensor] |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
batch_size = len(inputs) |
|
device = inputs[0].device |
|
hidden_size = inputs[0].shape[1] |
|
max_length = max(inputs[k].shape[0] for k in range(batch_size)) |
|
attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device) |
|
|
|
inputs_p, labels_p = [], [] |
|
for k in range(batch_size): |
|
size_pk = max_length - inputs[k].shape[0] |
|
inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device) |
|
labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device) |
|
if self.tokenizer.padding_side == "right": |
|
attention_mask[k, inputs[k].shape[0] :] = False |
|
inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0) |
|
labels_pk = torch.cat([labels[k], labels_pk], dim=0) |
|
else: |
|
attention_mask[k, : -inputs[k].shape[0]] = False |
|
inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0) |
|
labels_pk = torch.cat([labels_pk, labels[k]], dim=0) |
|
inputs_p.append(inputs_pk) |
|
labels_p.append(labels_pk) |
|
|
|
inputs = torch.stack(inputs_p, dim=0) |
|
labels = torch.stack(labels_p, dim=0) |
|
return inputs, labels, attention_mask |
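
    # Padding sketch: given two embedded sequences of lengths 5 and 3 with hidden
    # size H, the outputs are
    #   inputs:         (2, 5, H)  zero-padded
    #   labels:         (2, 5)     padded with IGNORE_INDEX
    #   attention_mask: (2, 5)     False over the padded positions
    # with padding applied on the right or left according to `tokenizer.padding_side`.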
|
|
|
def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels): |
|
|
|
PROCESS_GROUP_MANAGER = get_pg_manager() |
|
|
|
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
sp_degree = PROCESS_GROUP_MANAGER.sp_degree |
|
sp_rank = PROCESS_GROUP_MANAGER.sp_rank |
|
sp_group = PROCESS_GROUP_MANAGER.sp_pg |
|
ring_degree = PROCESS_GROUP_MANAGER.ring_degree |
|
ring_rank = PROCESS_GROUP_MANAGER.ring_rank |
|
ring_type = PROCESS_GROUP_MANAGER.ring_type |
|
ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree |
|
ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank |
|
|
|
bs, shard_seqlen = position_ids.shape |
|
sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)] |
|
dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group) |
|
sp_seq_len_cat = torch.cat(sp_seq_len, dim=0) |
|
|
|
if sp_rank == 0: |
|
original_start_id = 0 |
|
else: |
|
original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item() |
|
original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item() |
|
|
|
|
|
all_inputs_embeds = torch.zeros( |
|
bs, |
|
torch.sum(sp_seq_len_cat), |
|
inputs_embeds.shape[-1], |
|
dtype=inputs_embeds.dtype, |
|
device=inputs_embeds.device, |
|
).contiguous() |
|
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds |
|
dist.barrier(group=sp_group) |
|
dist.all_reduce(all_inputs_embeds, group=sp_group) |
|
dist.barrier(group=sp_group) |
|
|
|
attention_mask_list = [ |
|
torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device) |
|
for i in range(sp_degree) |
|
] |
|
position_ids_list = [ |
|
torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device) |
|
for i in range(sp_degree) |
|
] |
|
labels_list = [ |
|
torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree) |
|
] |
|
|
|
dist.all_gather(attention_mask_list, attention_mask, group=sp_group) |
|
dist.all_gather(position_ids_list, position_ids, group=sp_group) |
|
dist.all_gather(labels_list, labels, group=sp_group) |
|
|
|
effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)] |
|
effective_seqlen = torch.stack(effective_seqlen_list, dim=-1) |
|
effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0) |
|
|
|
global_attention_mask_list = [] |
|
global_position_ids_list = [] |
|
global_labels_list = [] |
|
global_inputs_embeds_list = [] |
|
for i in range(bs): |
|
global_attention_mask_batch_list = [] |
|
global_position_ids_batch_list = [] |
|
global_labels_batch_list = [] |
|
global_inputs_embeds_batch_list = [] |
|
for j in range(sp_degree): |
|
eff_len = effective_seqlen_batch_list[i][j] |
|
prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0 |
|
|
|
global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len]) |
|
global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len]) |
|
global_labels_batch_list.append(labels_list[j][i, :eff_len]) |
|
global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :]) |
|
global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0)) |
|
global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0)) |
|
global_labels_list.append(torch.cat(global_labels_batch_list, dim=0)) |
|
global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0)) |
|
|
|
global_attention_mask = torch.nn.utils.rnn.pad_sequence( |
|
global_attention_mask_list, batch_first=True, padding_value=False |
|
) |
|
global_position_ids = torch.nn.utils.rnn.pad_sequence( |
|
global_position_ids_list, batch_first=True, padding_value=-1 |
|
) |
|
global_labels = torch.nn.utils.rnn.pad_sequence( |
|
global_labels_list, batch_first=True, padding_value=IGNORE_INDEX |
|
) |
|
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence( |
|
global_inputs_embeds_list, batch_first=True, padding_value=0 |
|
) |
|
|
|
|
|
if ring_degree > 1: |
|
total_effective_seqlen = torch.sum(effective_seqlen, dim=1) |
|
new_seqlen_per_rank = total_effective_seqlen // sp_degree |
|
assert torch.all( |
|
total_effective_seqlen % sp_degree == 0 |
|
), "total_effective_seqlen must be divisible by sp_degree" |
|
|
|
max_new_seqlen = torch.max(new_seqlen_per_rank).item() |
|
|
|
new_attention_mask = torch.zeros( |
|
(bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device |
|
) |
|
new_position_ids = torch.zeros( |
|
(bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device |
|
) |
|
new_labels = torch.full( |
|
(bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device |
|
) |
|
new_inputs_embeds = torch.zeros( |
|
(bs, max_new_seqlen, global_inputs_embeds.shape[-1]), |
|
dtype=global_inputs_embeds.dtype, |
|
device=global_inputs_embeds.device, |
|
) |
|
|
|
if ring_type == "ring_varlen": |
|
for i in range(bs): |
|
start_idx = new_seqlen_per_rank[i] * sp_rank |
|
end_idx = start_idx + new_seqlen_per_rank[i] |
|
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx] |
|
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx] |
|
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx] |
|
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[ |
|
i, start_idx:end_idx, : |
|
] |
|
elif ring_type == "zigzag_ring_varlen": |
|
chunk_size = total_effective_seqlen // (2 * sp_degree) |
|
for i in range(bs): |
|
|
|
if sp_degree == ring_degree: |
|
forward_rank_idx = sp_rank |
|
backward_rank_idx = 2 * sp_degree - sp_rank - 1 |
|
else: |
|
ulysses_offset = ulysses_rank * ring_degree * 2 |
|
forward_rank_idx = ring_rank + ulysses_offset |
|
backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset |
|
|
|
|
|
start_idx_fwd = forward_rank_idx * chunk_size[i] |
|
end_idx_fwd = start_idx_fwd + chunk_size[i] |
|
|
|
start_idx_bwd = backward_rank_idx * chunk_size[i] |
|
end_idx_bwd = start_idx_bwd + chunk_size[i] |
|
|
|
|
|
new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd] |
|
new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[ |
|
i, start_idx_bwd:end_idx_bwd |
|
] |
|
|
|
new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd] |
|
new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[ |
|
i, start_idx_bwd:end_idx_bwd |
|
] |
|
|
|
new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd] |
|
new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd] |
|
|
|
new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :] |
|
new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[ |
|
i, start_idx_bwd:end_idx_bwd, : |
|
] |
|
else: |
|
raise ValueError(f"Invalid ring_type: {ring_type}") |
|
else: |
|
global_seq_len = global_attention_mask.shape[-1] |
|
seq_len_sharded = global_seq_len // sp_degree |
|
start_idx_reshard = seq_len_sharded * sp_rank |
|
end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len |
|
|
|
new_attention_mask = torch.narrow( |
|
global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
) |
|
new_position_ids = torch.narrow( |
|
global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
) |
|
new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard) |
|
new_inputs_embeds = torch.narrow( |
|
global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
) |
|
|
|
return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels |
|
|
|
device = inputs_embeds.device |
|
batch_size = inputs_embeds.shape[0] |
|
seqlens = [attention_mask[k].sum().item() for k in range(batch_size)] |
|
|
|
|
|
inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)] |
|
attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] |
|
position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] |
|
labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)] |
|
|
|
|
|
inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device)) |
|
attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device)) |
|
position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device)) |
|
labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device)) |
|
|
|
|
|
for label in labels_p: |
|
label[0] = IGNORE_INDEX |
|
|
|
|
|
inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0) |
|
attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0) |
|
position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0) |
|
labels_p = torch.cat(labels_p, dim=0).unsqueeze(0) |
|
|
|
        if hasattr(self, "pad_to_multiple_of"):
|
assert len(labels_p.shape) == 2 |
|
batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1] |
|
hidden_size = inputs_embeds_p.shape[-1] |
|
|
|
if max_length % self.pad_to_multiple_of != 0: |
|
max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of |
|
difference = max_length - cur_length |
|
|
|
inputs_embeds_p = torch.cat( |
|
( |
|
inputs_embeds_p, |
|
torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p), |
|
), |
|
dim=1, |
|
) |
|
labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1) |
|
attention_mask_p = torch.cat( |
|
( |
|
attention_mask_p, |
|
torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p), |
|
), |
|
dim=1, |
|
) |
|
position_ids_p = torch.cat( |
|
(position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1 |
|
) |
|
|
|
return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p |
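
    # Packing sketch for the non-sequence-parallel path above: a padded batch of
    # shape (B, L, H) is flattened into a single packed sequence of shape
    # (1, sum_of_effective_lengths + 1, H). Per-sequence lengths are communicated
    # through `set_seqlens_in_batch` in `forward`, position ids restart at 0 for
    # each sequence, and each sequence's first label is set to IGNORE_INDEX so loss
    # is never computed across a sequence boundary.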
|
|
|
def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]: |
|
raise NotImplementedError("This method is not implemented for VILA model.") |
|
|
|
import xgrammar as xgr |
|
|
|
logging.info("[XGrammar] Compiling grammar for contrained output") |
|
|
|
if self.grammar_compiler is None: |
|
|
|
self.grammar_compiler = xgr.GrammarCompiler( |
|
xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size) |
|
) |
|
|
|
if response_format.type == "json_schema": |
|
compiled_grammar = self.grammar_compiler.compile_json_schema( |
|
response_format.json_schema.schema_, |
|
indent=2, |
|
) |
|
else: |
|
compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar() |
|
|
|
return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)] |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
media: Optional[Dict[str, List[torch.Tensor]]] = None, |
|
images: Optional[torch.FloatTensor] = None, |
|
        media_config: Optional[Dict[str, Dict[str, Any]]] = None,
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
packing: bool = True, |
|
force_packing: bool = False, |
|
seqlens_in_batch: Optional[torch.LongTensor] = None, |
|
dpo_forward: bool = False, |
|
**kwargs, |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
self.freezed_module_patch() |
|
|
|
if images is not None: |
|
if media is not None: |
|
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.") |
|
print("The 'images' argument is deprecated. Please use 'media' instead.") |
|
media = {"image": images} |
|
|
|
if media_config is None: |
|
media_config = defaultdict(dict) |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask) |
|
|
|
if force_packing or (packing and self.training and not dpo_forward): |
|
if seqlens_in_batch is None: |
|
seqlens_in_batch = torch.sum(attention_mask, dim=1) |
|
set_seqlens_in_batch(seqlens_in_batch) |
|
|
|
(inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data( |
|
inputs_embeds, attention_mask, position_ids, labels |
|
) |
|
|
|
outputs = self.llm( |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
labels=labels, |
|
**kwargs, |
|
) |
|
|
|
if self.training and getattr(self.config, "time_token_ids", []): |
|
outputs.loss = soft_cross_entropy( |
|
outputs.logits, |
|
labels, |
|
soft_tokens=self.config.time_token_ids, |
|
std=self.config.soft_ce_std, |
|
) |
|
|
|
if dpo_forward: |
|
return outputs.logits, labels |
|
|
|
return outputs |
|
|
|
|
|
def generate( |
|
self, |
|
        input_ids: Optional[torch.LongTensor] = None,
|
media: Optional[Dict[str, List[torch.Tensor]]] = None, |
|
        media_config: Optional[Dict[str, Dict[str, Any]]] = None,
|
attention_mask: Optional[torch.LongTensor] = None, |
|
return_output_ids_only: bool = False, |
|
**generation_kwargs, |
|
) -> torch.LongTensor: |
|
model_training_status = False |
|
if self.training: |
|
            warnings.warn(
                "Model is in training mode, which pads on the right and is not recommended for generation. Temporarily switching to eval mode; the training state is restored after generation."
            )
|
self.eval() |
|
model_training_status = True |
|
""" |
|
input_tokens: <image> describe the image |
|
media: [Tensor(1, 3, 384, 384), ] |
|
-----------> |
|
input_tokens: 36000 001 002 003 004 |
|
input_emds: <media emd> 001 002 003 004 |
|
""" |
|
|
|
|
|
inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask) |
|
|
|
|
|
output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs) |
|
|
|
|
|
if return_output_ids_only: |
|
return_value = output_ids |
|
else: |
|
|
|
|
|
generation_config = generation_kwargs.get("generation_config", None) |
|
if generation_config is not None: |
|
num_generations = generation_config.num_return_sequences |
|
repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0) |
|
return_value = torch.cat([repeat_input_ids, output_ids], dim=-1) |
|
else: |
|
return_value = torch.cat([input_ids, output_ids], dim=-1) |
|
|
|
if model_training_status: |
|
|
|
self.train() |
|
|
|
return return_value |
|
|
|
@torch.inference_mode() |
|
def generate_content( |
|
self, |
|
prompt: Union[str, List], |
|
generation_config: Optional[GenerationConfig] = None, |
|
response_format=None, |
|
) -> str: |
|
|
|
conversation = [{"from": "human", "value": prompt}] |
|
|
|
|
|
xgr_logits_processor = None |
|
|
|
|
|
|
|
|
|
media = extract_media(conversation, self.config) |
|
|
|
|
|
media_config = defaultdict(dict) |
|
for name in media: |
|
if name == "image": |
|
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]: |
|
self.config.image_processor = self.vision_tower.image_processor |
|
if self.config.image_aspect_ratio == "dynamic": |
|
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half() |
|
conversation[0]["value"] = conversation[0]["value"].replace( |
|
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0] |
|
) |
|
else: |
|
if type(self.config.s2_scales) is str: |
|
self.config.s2_scales = list(map(int, self.config.s2_scales.split(","))) |
|
images, block_sizes = process_image( |
|
media["image"][0], self.config, None, enable_dynamic_s2=True |
|
) |
|
images = images.half() |
|
media_config[name]["block_sizes"] = [block_sizes] |
|
else: |
|
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half() |
|
media[name] = [image for image in images] |
|
elif name == "video": |
|
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1: |
|
media[name] = [ |
|
process_images( |
|
images, |
|
self.vision_tower.image_processor, |
|
self.config, |
|
enable_dynamic_res=True, |
|
max_tiles=self.config.video_max_tiles, |
|
).half() |
|
for images in media[name] |
|
] |
|
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1: |
|
self.config.image_processor = self.vision_tower.image_processor |
|
if type(self.config.s2_scales) is str: |
|
self.config.s2_scales = list(map(int, self.config.s2_scales.split(","))) |
|
media[name] = [ |
|
torch.cat( |
|
[ |
|
process_image( |
|
image, |
|
self.config, |
|
None, |
|
enable_dynamic_s2=True, |
|
max_tiles=self.config.video_max_tiles, |
|
)[0].half() |
|
for image in images |
|
] |
|
) |
|
for images in media[name] |
|
] |
|
else: |
|
media[name] = [ |
|
process_images(images, self.vision_tower.image_processor, self.config).half() |
|
for images in media[name] |
|
] |
|
else: |
|
raise ValueError(f"Unsupported media type: {name}") |
|
|
|
|
|
input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda() |
|
|
|
|
|
generation_config = generation_config or self.default_generation_config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
output_ids = self.generate( |
|
input_ids=input_ids, |
|
media=media, |
|
media_config=media_config, |
|
generation_config=generation_config, |
|
logits_processor=xgr_logits_processor, |
|
) |
|
except ValueError: |
|
if not generation_config.do_sample: |
|
raise |
|
|
|
logging.warning("Generation failed with sampling, retrying with greedy decoding.") |
|
generation_config.do_sample = False |
|
output_ids = self.generate( |
|
input_ids=input_ids, |
|
media=media, |
|
media_config=media_config, |
|
generation_config=generation_config, |
|
logits_processor=xgr_logits_processor, |
|
) |
|
|
|
|
|
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() |
|
return response |
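
    # Usage sketch (the checkpoint and image paths are placeholders):
    #
    #   from PIL import Image
    #
    #   model = AutoModel.from_pretrained("path/to/vila-checkpoint", trust_remote_code=True).cuda()
    #   response = model.generate_content([Image.open("demo.jpg"), "Describe this image."])
    #   print(response)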
|
|
|
@property |
|
def default_generation_config(self) -> GenerationConfig: |
|
generation_config = copy.deepcopy(self.generation_config or GenerationConfig()) |
|
if self.tokenizer.eos_token_id is None: |
|
raise ValueError("Tokenizer must have an EOS token") |
|
if generation_config.max_length == GenerationConfig().max_length: |
|
generation_config.max_length = self.tokenizer.model_max_length |
|
if generation_config.pad_token_id is None: |
|
generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id |
|
if generation_config.bos_token_id is None: |
|
generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id |
|
if generation_config.eos_token_id is None: |
|
generation_config.eos_token_id = self.tokenizer.eos_token_id |
|
return generation_config |
|
|