M4 Collection
M4 is a collection of omni multimodal multiplexing models (7 items).
LongVA-7B-Qwen2-Audio is an extension of LongVA-7B, further trained using the LLaVA-NeXT-Audio dataset for 0.4 epochs.
Please refer to M4 to install the relevant packages.
import os
from PIL import Image
import numpy as np
import torchaudio
import torch
from decord import VideoReader, cpu
import whisper
# fix seed
torch.manual_seed(0)
from intersuit.model.builder import load_pretrained_model
from intersuit.mm_utils import tokenizer_image_speech_tokens, process_images
from intersuit.constants import IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX
import warnings
warnings.filterwarnings("ignore")
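# model checkpoint, demo video/audio inputs, and generation settings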
model_path = "ColorfulAI/LongVA-7B-Qwen2-Audio"
video_path = "local_demo/assets/water.mp4"
audio_path = "local_demo/wav/infer.wav"
max_frames_num = 16 # you can increase this to several thousand frames, as long as your GPU memory can handle it :)
gen_kwargs = {"do_sample": True, "temperature": 0.5, "top_p": None, "num_beams": 1, "use_cache": True, "max_new_tokens": 1024}
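# load the tokenizer, model, and image processor (the fourth return value is unused here)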
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "llava_qwen", device_map="cuda:0")
query = "Give a detailed caption of the video as if I am blind."
query = None # comment out this line to have ChatTTS convert the query above into speech; leave it to use the pre-recorded audio at audio_path
#video input
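# build the chat prompt in the Qwen2 chat format; the <image> and <speech> placeholders are replaced by IMAGE_TOKEN_INDEX and SPEECH_TOKEN_INDEX during tokenization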
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image><|im_end|>\n<|im_start|>user\n<speech>\n<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer_image_speech_tokens(prompt, tokenizer, IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
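# uniformly sample max_frames_num frames from the video and preprocess them into a float16 tensor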
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
video_tensor = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.float16)
#audio input
# process speech for input question
if query is not None:
    import ChatTTS
    chat = ChatTTS.Chat()
    chat.load(source='local', compile=True)
    audio_path = "./local_demo/wav/" + "infer.wav"
    if os.path.exists(audio_path): os.remove(audio_path) # refresh
    if not os.path.exists(audio_path):
        wav = chat.infer(query)
        try:
            torchaudio.save(audio_path, torch.from_numpy(wav).unsqueeze(0), 24000)
        except:
            torchaudio.save(audio_path, torch.from_numpy(wav), 24000)
    print(f"Human: {query}")
else:
    print("Human: <audio>")
speech = whisper.load_audio(audio_path)
speech = whisper.pad_or_trim(speech)
speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0).to(device=model.device, dtype=torch.float16)
speech_length = torch.LongTensor([speech.shape[0]]).to(model.device)
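# generate a response conditioned on both the sampled video frames and the speech features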
with torch.inference_mode():
    output_ids = model.generate(input_ids, images=[video_tensor], modalities=["video"], speeches=speech.unsqueeze(0), speech_lengths=speech_length, **gen_kwargs)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(f"Agent: {outputs}")