import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import spaces
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
#####################################
# 1. Load Gemma3 Model & Processor
#####################################
MODEL_ID = "google/gemma-3-12b-it"  # Example placeholder
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda")
model.eval()
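
# Note: bfloat16 assumes a reasonably recent GPU (Ampere or newer). A
# hypothetical fallback for older hardware is float16, e.g.:
#   model = Gemma3ForConditionalGeneration.from_pretrained(
#       MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
#   ).to("cuda")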
#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
    """
    Downsamples the video file to `num_frames` evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp (seconds).
    Returns a list of (PIL.Image, float) tuples; empty if the video is unreadable.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            # OpenCV returns BGR; convert to RGB before building the PIL image
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
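
# Example (hypothetical sanity check; "clip.mp4" is a placeholder path):
#   for img, ts in downsample_video("clip.mp4", num_frames=4):
#       print(img.size, "at", ts, "s")  # e.g. (1280, 720) at 1.33 s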
#####################################
# 2.5: Parse Categories from Model Output
#####################################
def parse_inferred_categories(generated_text):
    """
    A naive parser that looks for lines starting with 'Category:'
    (case-insensitive) and collects the text after the colon as the
    category name. Example lines in model output:
        Category: Nutrition
        Category: Outdoor Scenes
    Returns a list of category strings.
    """
    categories = []
    for line in generated_text.split("\n"):
        line = line.strip()
        # Check if the line starts with 'Category:' (case-insensitive)
        if line.lower().startswith("category:"):
            # Extract everything after the first colon
            cat = line.split(":", 1)[1].strip()
            if cat:
                categories.append(cat)
    return categories
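
# Example:
#   parse_inferred_categories("Category: Nutrition\nNo match here\ncategory: Travel")
#   -> ["Nutrition", "Travel"]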
#####################################
# 3. The Inference Function
#####################################
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
def video_inference(video_file, duration):
    """
    - Takes a recorded video file and a chosen duration (string). The duration
      is a UI hint only and is not used in the computation.
    - Downsamples the video and passes the frames to the Gemma3 model.
    - Returns model-generated text + a bar chart with categories derived
      from that text.
    """
    if video_file is None:
        return "No video provided.", None
    # 3.1: Downsample the recorded video
    frames = downsample_video(video_file)
    if not frames:
        return "Could not read frames from video.", None
    # 3.2: Construct the prompt: one user turn holding the instruction,
    # followed by a (timestamp text, image) pair per sampled frame
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
        }
    ]
    for (image, ts) in frames:
        messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
        messages[0]["content"].append({"type": "image", "image": image})
    # Render the chat template to a string; it inserts the image placeholder tokens
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Collect images in the same order they appear in the prompt
    frame_images = [img for (img, _) in frames]
    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True
    ).to("cuda")
    # 3.3: Generate text output, streamed from a background thread.
    # TextIteratorStreamer expects a tokenizer, so pass processor.tokenizer.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
    thread.join()
    # 3.4: Parse categories from the model output, with a fallback if none are found
    categories = parse_inferred_categories(generated_text)
    if not categories:
        categories = ["Category A", "Category B", "Category C"]
    # The model output carries no numeric scores, so chart dummy values per category
    values = [random.randint(1, 10) for _ in categories]
    # 3.5: Create the bar chart (repeat the palette, then trim it to the bar count)
    fig, ax = plt.subplots()
    colors = (["#4B0082", "#9370DB", "#4B0082"] * (len(categories) // 3 + 1))[:len(categories)]
    ax.bar(categories, values, color=colors)
    ax.set_title("Inferred Categories from Model Output")
    ax.set_ylabel("Value")
    ax.set_xlabel("Categories")
    plt.xticks(rotation=30, ha="right")
    return generated_text, fig
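
# Assumption sketch: if live token-by-token updates in the UI are wanted,
# video_inference could be turned into a generator that Gradio streams,
# yielding partial text as tokens arrive and the chart at the end:
#
#   generated_text = ""
#   for new_text in streamer:
#       generated_text += new_text
#       yield generated_text, None  # progressively update the textbox
#   ... build fig as above ...
#   yield generated_text, fig       # final update with the chart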
#####################################
# 4. Build a Professional Gradio UI
#####################################
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Gemma3 Live Video Analysis**
        Record a video (from webcam or file), then click **Stop**.
        Next, click **Analyze** to run the model and see the text and chart outputs.
        """)
        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop."
                )
                video = gr.Video(
                    label="Webcam Recording (press Record, then Stop)",
                    format="mp4"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")
        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot]
        )
    return demo
if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)
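
# Note (assumption): recent Gradio versions enable request queueing by default;
# on older versions, a streaming generator would need app.queue().launch(debug=True).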