import os
import random
import uuid
import json
import time
import asyncio
import tempfile
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.image_utils import load_image

# Constants for text generation
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Chat command that routes the request to the video branch of `generate`.
VIDEO_INFER_COMMAND = "@video-infer"
# Number of evenly spaced frames sampled from a video for that branch.
VIDEO_NUM_FRAMES = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load text-only model and tokenizer
model_id = "prithivMLmods/Pocket-Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# Load the Qwen2-VL model and processor used for image and video inputs.
MODEL_ID = "prithivMLmods/Callisto-OCR3-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()


def clean_chat_history(chat_history):
    """
    Return a new list holding only the chat entries whose "content" is a string.

    Entries that are not dicts, or whose "content" is not a str (e.g. file
    attachments), are dropped so the history can be fed to the text-only
    chat template. The input list is not modified.
    """
    return [
        msg
        for msg in chat_history
        if isinstance(msg, dict) and isinstance(msg.get("content"), str)
    ]


def downsample_video(video_path, num_frames=VIDEO_NUM_FRAMES):
    """
    Sample up to `num_frames` evenly spaced frames from a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: How many evenly spaced frames to sample (default 10).

    Returns:
        A list of (PIL RGB image, timestamp in seconds rounded to 2 decimals)
        tuples in ascending frame order. Frames that fail to decode are
        skipped, and an empty list is returned if the video reports no frames.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0:
            # Nothing to sample; linspace would otherwise produce negative indices.
            return []

        # np.unique drops duplicate indices when the video is shorter than num_frames.
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )

        frames = []
        for idx in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
            # OpenCV may report 0 FPS when the container lacks rate metadata;
            # avoid ZeroDivisionError in that case.
            timestamp = round(int(idx) / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
        return frames
    finally:
        vidcap.release()


def progress_bar_html(label: str) -> str:
    """
    Return the HTML snippet yielded to the chat while a response is generating.

    NOTE(review): this function was documented as returning a thin, dark-red
    animated progress bar, but the markup in this copy of the file only
    contains the label text. The HTML/CSS appears to have been lost; restore
    it from version control if the bar is expected.
    """
    return f'''
{label}
'''


def _stream_qwen(inputs, label, generation_kwargs):
    """
    Run `model_m.generate` in a background thread and stream its output.

    Yields the progress-bar HTML first, then the accumulated decoded text
    after every streamed chunk, with any literal "<|im_end|>" removed.
    """
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model_m.generate,
        kwargs={**inputs, "streamer": streamer, **generation_kwargs},
    )
    thread.start()
    buffer = ""
    yield progress_bar_html(label)
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer


def _generate_video(prompt, files, generation_kwargs):
    """Answer `prompt` about the first uploaded file (assumed to be a video) with Qwen2VL."""
    user_content = [{"type": "text", "text": prompt}]
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": user_content},
    ]
    # Frames are written to a temporary directory that is removed once the
    # processor has read them, instead of accumulating PNGs in the CWD.
    with tempfile.TemporaryDirectory() as frame_dir:
        if files:
            # Assume the first file is a video.
            for image, timestamp in downsample_video(files[0]):
                image_path = os.path.join(frame_dir, f"video_frame_{uuid.uuid4().hex}.png")
                image.save(image_path)
                user_content.append({"type": "text", "text": f"Frame {timestamp}:"})
                user_content.append({"type": "image", "url": image_path})
        # With tokenize=True the processor loads the referenced frame files,
        # so this call must happen before the temporary directory is deleted.
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to("cuda")
    yield from _stream_qwen(inputs, "Processing video with Qwen2VL", generation_kwargs)


def _generate_multimodal(text, files, generation_kwargs):
    """Answer `text` about the uploaded images with Qwen2VL (chat history is not used)."""
    images = [load_image(path) for path in files]
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
    yield from _stream_qwen(inputs, "Thinking...", generation_kwargs)


def _generate_text(text, chat_history, generation_kwargs):
    """Continue a text-only conversation with the Llama model, streaming the reply."""
    conversation = clean_chat_history(chat_history)
    conversation.append({"role": "user", "content": text})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest user turn survives.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    t = Thread(
        target=model.generate,
        kwargs={"input_ids": input_ids, "streamer": streamer, **generation_kwargs, "num_beams": 1},
    )
    t.start()

    outputs = []
    yield progress_bar_html("Processing...")
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
    # Final yield also replaces the progress bar if the streamer produced nothing.
    yield "".join(outputs)


@spaces.GPU(duration=60, enable_queue=True)
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int = 1024,
             temperature: float = 0.6,
             top_p: float = 0.9,
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """
    Stream chatbot responses for text, image, and video inputs.

    Routing:
      - Text starting with "@video-infer" (case-insensitive, leading
        whitespace ignored): the remainder is the prompt, and the first
        uploaded file is sampled into frames and sent to Qwen2VL.
      - Any uploaded files otherwise: treated as images and sent to Qwen2VL.
      - No files: text-only chat with the Llama model, using `chat_history`.

    All branches use sampling with the given max_new_tokens, temperature,
    top_p, top_k, and repetition_penalty. Each branch first yields a progress
    HTML snippet, then the accumulated response text.
    """
    text = input_dict.get("text") or ""
    files = input_dict.get("files") or []
    stripped = text.strip()

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    # Branch for video processing with Qwen2VL.
    if stripped.lower().startswith(VIDEO_INFER_COMMAND):
        # Slice the *stripped* text so leading whitespace does not shift the cut.
        prompt = stripped[len(VIDEO_INFER_COMMAND):].strip()
        yield from _generate_video(prompt, files, generation_kwargs)
        return

    # Normal text or multimodal conversation processing.
    if files:
        yield from _generate_multimodal(text, files, generation_kwargs)
    else:
        yield from _generate_text(text, chat_history, generation_kwargs)


# Create the Gradio ChatInterface with a multimodal textbox and generation-parameter sliders.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
        ["Write the code that converts temperatures between celsius and fahrenheit"],
        [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    theme="YTheme/GMaterial",
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)