import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import spaces
import time
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image

#####################################
# 1. Load Gemma3 Model & Processor
#####################################
MODEL_ID = "google/gemma-3-12b-it"  # Example placeholder
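# Note (assumption, not tested here): other vision-capable Gemma 3 instruction-tuned
# checkpoints, e.g. the smaller "google/gemma-3-4b-it", should work as drop-in
# replacements for MODEL_ID if GPU memory is limited.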

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda")
model.eval()

#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
    """
    Downsamples the video file to `num_frames` evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames

    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
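
# Illustrative usage sketch (hypothetical local file, not executed at import time):
#   frames = downsample_video("clip.mp4", num_frames=10)
#   # -> list of up to 10 (PIL.Image, timestamp-in-seconds) tuples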

#####################################
# 2.5: Parse Categories from Model Output
#####################################
def parse_inferred_categories(generated_text):
    """
    A naive parser that looks for lines starting with 'Category:' 
    and collects the text after that as the category name.
    Example lines in model output:
        Category: Nutrition
        Category: Outdoor Scenes
    Returns a list of category strings.
    """
    categories = []
    for line in generated_text.split("\n"):
        line = line.strip()
        # Check if the line starts with 'Category:' (case-insensitive)
        if line.lower().startswith("category:"):
            # Extract everything after 'Category:'
            cat = line.split(":", 1)[1].strip()
            if cat:
                categories.append(cat)
    return categories
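
# Illustrative example (made-up model output):
#   parse_inferred_categories("The clip shows a picnic.\nCategory: Nutrition\nCategory: Outdoor Scenes")
#   # -> ["Nutrition", "Outdoor Scenes"]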

#####################################
# 3. The Inference Function
#####################################
@spaces.GPU
def video_inference(video_file, duration):
    """
    - Takes a recorded video file and a chosen duration (string).
    - Downsamples the video, passes frames to the Gemma3 model for inference.
    - Returns model-generated text + a bar chart with categories derived from that text.
    """
    if video_file is None:
        return "No video provided.", None

    # 3.1: Downsample the recorded video
    frames = downsample_video(video_file)
    if not frames:
        return "Could not read frames from video.", None

    # 3.2: Construct prompt
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
        }
    ]
    # Add frames (with timestamp) to the messages
    for (image, ts) in frames:
        messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
        messages[0]["content"].append({"type": "image", "image": image})

    # Prepare final prompt
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Collect images for model
    frame_images = [img for (img, _) in frames]

    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True
    ).to("cuda", dtype=torch.bfloat16)  # cast floating-point tensors (pixel values) to the model dtype

    # 3.3: Generate text output (streaming)
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        time.sleep(0.01)
    thread.join()  # make sure generation has fully finished before post-processing

    # 3.4: Parse categories from model output
    categories = parse_inferred_categories(generated_text)
    # If no categories were found, use fallback
    if not categories:
        categories = ["Category A", "Category B", "Category C"]

    # Create dummy values for each category
    values = [random.randint(1, 10) for _ in categories]

    # 3.5: Create bar chart
    fig, ax = plt.subplots()
    palette = ["#4B0082", "#9370DB", "#4B0082"]
    ax.bar(categories, values, color=[palette[i % len(palette)] for i in range(len(categories))])
    ax.set_title("Inferred Categories from Model Output")
    ax.set_ylabel("Value")
    ax.set_xlabel("Categories")
    plt.xticks(rotation=30, ha="right")

    return generated_text, fig
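
# Standalone usage sketch (outside the Gradio UI; assumes a local MP4 path exists):
#   text, chart = video_inference("recording.mp4", "10")
#   print(text)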

#####################################
# 4. Build a Professional Gradio UI
#####################################
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Gemma3 (or Qwen2.5-VL) Live Video Analysis**  
        Record a video (from webcam or file), then click **Stop**.  
        Next, click **Analyze** to run the model and see textual + chart outputs.
        """)

        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop."
                )
                video = gr.Video(
                    label="Webcam Recording (press Record, then Stop)",
                    format="mp4"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")

        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot]
        )

    return demo

if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)