import gradio as gr import torch import numpy as np import cv2 import matplotlib.pyplot as plt import random import spaces import time from PIL import Image from threading import Thread from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer from transformers.image_utils import load_image ##################################### # 1. Load Gemma3 Model & Processor ##################################### MODEL_ID = "google/gemma-3-12b-it" # Example placeholder processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) model = Gemma3ForConditionalGeneration.from_pretrained( MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16 ).to("cuda") model.eval() ##################################### # 2. Helper Function: Downsample Video ##################################### def downsample_video(video_path, num_frames=10): """ Downsamples the video file to `num_frames` evenly spaced frames. Each frame is converted to a PIL Image along with its timestamp. """ vidcap = cv2.VideoCapture(video_path) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = vidcap.get(cv2.CAP_PROP_FPS) frames = [] if total_frames <= 0 or fps <= 0: vidcap.release() return frames frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) for i in frame_indices: vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) vidcap.release() return frames ##################################### # 2.5: Parse Categories from Model Output ##################################### def parse_inferred_categories(generated_text): """ A naive parser that looks for lines starting with 'Category:' and collects the text after that as the category name. Example lines in model output: Category: Nutrition Category: Outdoor Scenes Returns a list of category strings. """ categories = [] for line in generated_text.split("\n"): line = line.strip() # Check if the line starts with 'Category:' (case-insensitive) if line.lower().startswith("category:"): # Extract everything after 'Category:' cat = line.split(":", 1)[1].strip() if cat: categories.append(cat) return categories ##################################### # 3. The Inference Function ##################################### @spaces.GPU def video_inference(video_file, duration): """ - Takes a recorded video file and a chosen duration (string). - Downsamples the video, passes frames to the Gemma3 model for inference. - Returns model-generated text + a bar chart with categories derived from that text. """ if video_file is None: return "No video provided.", None # 3.1: Downsample the recorded video frames = downsample_video(video_file) if not frames: return "Could not read frames from video.", None # 3.2: Construct prompt messages = [ { "role": "user", "content": [{"type": "text", "text": "Please describe what's happening in this video."}] } ] # Add frames (with timestamp) to the messages for (image, ts) in frames: messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"}) messages[0]["content"].append({"type": "image", "image": image}) # Prepare final prompt prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # Collect images for model frame_images = [img for (img, _) in frames] inputs = processor( text=[prompt], images=frame_images, return_tensors="pt", padding=True ).to("cuda") # 3.3: Generate text output (streaming) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() generated_text = "" for new_text in streamer: generated_text += new_text time.sleep(0.01) # 3.4: Parse categories from model output categories = parse_inferred_categories(generated_text) # If no categories were found, use fallback if not categories: categories = ["Category A", "Category B", "Category C"] # Create dummy values for each category values = [random.randint(1, 10) for _ in categories] # 3.5: Create bar chart fig, ax = plt.subplots() ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"]*(len(categories)//3+1)) ax.set_title("Inferred Categories from Model Output") ax.set_ylabel("Value") ax.set_xlabel("Categories") plt.xticks(rotation=30, ha="right") return generated_text, fig ##################################### # 4. Build a Professional Gradio UI ##################################### def build_app(): with gr.Blocks() as demo: gr.Markdown(""" # **Gemma3 (or Qwen2.5-VL) Live Video Analysis** Record a video (from webcam or file), then click **Stop**. Next, click **Analyze** to run the model and see textual + chart outputs. """) with gr.Row(): with gr.Column(): duration = gr.Radio( choices=["5", "10", "20", "30"], value="5", label="Suggested Recording Duration (seconds)", info="Select how long you plan to record before pressing Stop." ) video = gr.Video( label="Webcam Recording (press Record, then Stop)", format="mp4" ) analyze_btn = gr.Button("Analyze", variant="primary") with gr.Column(): output_text = gr.Textbox(label="Model Output") output_plot = gr.Plot(label="Analytics Chart") analyze_btn.click( fn=video_inference, inputs=[video, duration], outputs=[output_text, output_plot] ) return demo if __name__ == "__main__": app = build_app() app.launch(debug=True)