prithivMLmods committed · verified
Commit 554ae5a · 1 Parent(s): 20121ea

Update app.py

Files changed (1): app.py +47 -19
app.py CHANGED
@@ -12,9 +12,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 from transformers.image_utils import load_image
 
 #####################################
-# 1. Load Qwen2.5-VL Model & Processor
+# 1. Load Gemma3 Model & Processor
 #####################################
-MODEL_ID = "google/gemma-3-12b-it" # or "Qwen/Qwen2.5-VL-3B-Instruct"
+MODEL_ID = "google/gemma-3-12b-it" # Example placeholder
 
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = Gemma3ForConditionalGeneration.from_pretrained(
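Note on this hunk: the from_pretrained( call is truncated here; the remaining kwargs live outside the diff. A minimal sketch of how such a load typically completes, with assumed dtype and device placement (not the commit's actual arguments):

    import torch
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    MODEL_ID = "google/gemma-3-12b-it"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # assumed; the actual kwargs are outside the hunk
        device_map="auto",           # assumed
    )
    model.eval()  # matches the model.eval() context line in the next hunk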
@@ -27,7 +27,6 @@ model.eval()
 #####################################
 # 2. Helper Function: Downsample Video
 #####################################
-
 def downsample_video(video_path, num_frames=10):
     """
     Downsamples the video file to `num_frames` evenly spaced frames.
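The body of downsample_video sits outside the hunks shown; only the docstring above and the closing vidcap.release() / return frames below appear. A sketch of such a helper under assumed OpenCV + PIL usage (the name downsample_video_sketch and every detail here are illustrative, not the commit's code):

    import cv2
    import numpy as np
    from PIL import Image

    def downsample_video_sketch(video_path, num_frames=10):
        vidcap = cv2.VideoCapture(video_path)
        total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
        frames = []
        # Evenly spaced frame indices across the whole clip
        for idx in np.linspace(0, max(total - 1, 0), num_frames, dtype=int):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, image = vidcap.read()
            if not ok:
                continue
            # OpenCV reads BGR; convert to RGB for PIL, and keep a timestamp in seconds
            pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            frames.append((pil_image, round(idx / fps, 2)))
        vidcap.release()
        return frames  # list of (PIL.Image, timestamp) tuples, as the inference code expects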
@@ -53,6 +52,29 @@ def downsample_video(video_path, num_frames=10):
     vidcap.release()
     return frames
 
+#####################################
+# 2.5: Parse Categories from Model Output
+#####################################
+def parse_inferred_categories(generated_text):
+    """
+    A naive parser that looks for lines starting with 'Category:'
+    and collects the text after that as the category name.
+    Example lines in model output:
+        Category: Nutrition
+        Category: Outdoor Scenes
+    Returns a list of category strings.
+    """
+    categories = []
+    for line in generated_text.split("\n"):
+        line = line.strip()
+        # Check if the line starts with 'Category:' (case-insensitive)
+        if line.lower().startswith("category:"):
+            # Extract everything after 'Category:'
+            cat = line.split(":", 1)[1].strip()
+            if cat:
+                categories.append(cat)
+    return categories
+
 #####################################
 # 3. The Inference Function
 #####################################
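A quick trace of the new parser on hypothetical model output:

    sample = (
        "The clip shows a family picnic.\n"
        "Category: Outdoor Scenes\n"
        "category: Nutrition\n"
        "Notes: handheld camera"
    )
    print(parse_inferred_categories(sample))
    # ['Outdoor Scenes', 'Nutrition']  (matching is case-insensitive; other lines are ignored)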
@@ -60,8 +82,8 @@ def downsample_video(video_path, num_frames=10):
 def video_inference(video_file, duration):
     """
     - Takes a recorded video file and a chosen duration (string).
-    - Downsamples the video, passes frames to Qwen2.5-VL for inference.
-    - Returns model-generated text + a dummy bar chart as example analytics.
+    - Downsamples the video, passes frames to the Gemma3 model for inference.
+    - Returns model-generated text + a bar chart with categories derived from that text.
     """
     if video_file is None:
         return "No video provided.", None
@@ -71,23 +93,22 @@ def video_inference(video_file, duration):
     if not frames:
         return "Could not read frames from video.", None
 
-    # 3.2: Construct Qwen2.5-VL prompt
+    # 3.2: Construct prompt
     messages = [
         {
             "role": "user",
             "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
         }
     ]
-
     # Add frames (with timestamp) to the messages
     for (image, ts) in frames:
         messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
         messages[0]["content"].append({"type": "image", "image": image})
 
-    # Prepare final prompt for the model
+    # Prepare final prompt
     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-    # Qwen requires images in the same order. We'll just collect them:
+    # Collect images for model
     frame_images = [img for (img, _) in frames]
 
     inputs = processor(
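The processor( call is cut off at the end of this hunk. For multimodal processors the call generally takes the templated text plus the images in placeholder order; a sketch with assumed argument names (the commit's actual arguments are outside the hunk):

    inputs = processor(
        text=prompt,          # chat-templated prompt containing the image placeholders
        images=frame_images,  # PIL frames, in the same order as the placeholders
        return_tensors="pt",
    ).to(model.device)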
@@ -109,14 +130,22 @@ def video_inference(video_file, duration):
         generated_text += new_text
         time.sleep(0.01)
 
-    # 3.4: Dummy bar chart for demonstration
-    fig, ax = plt.subplots()
-    categories = ["Category A", "Category B", "Category C"]
+    # 3.4: Parse categories from model output
+    categories = parse_inferred_categories(generated_text)
+    # If no categories were found, use fallback
+    if not categories:
+        categories = ["Category A", "Category B", "Category C"]
+
+    # Create dummy values for each category
     values = [random.randint(1, 10) for _ in categories]
-    ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"])
-    ax.set_title("Example Analytics Chart")
+
+    # 3.5: Create bar chart
+    fig, ax = plt.subplots()
+    ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"]*(len(categories)//3+1))
+    ax.set_title("Inferred Categories from Model Output")
     ax.set_ylabel("Value")
-    ax.set_xlabel("Category")
+    ax.set_xlabel("Categories")
+    plt.xticks(rotation=30, ha="right")
 
     return generated_text, fig
 
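The generated_text += new_text loop at the top of this hunk drains the TextIteratorStreamer named in the imports; its setup lives outside the diff. A common pattern, sketched with assumed names and kwargs (not the commit's actual code):

    from threading import Thread

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation in a background thread so the streamer can be consumed as tokens arrive
    Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512)).start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text

On the chart: the color list replication ["#4B0082", "#9370DB", "#4B0082"]*(len(categories)//3+1) simply repeats the three-color palette enough times to cover however many categories the parser found; matplotlib ignores any surplus colors.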
@@ -126,9 +155,9 @@ def video_inference(video_file, duration):
 def build_app():
     with gr.Blocks() as demo:
         gr.Markdown("""
-        # **Qwen2.5-VL-7B-Instruct Live Video Analysis**
+        # **Gemma3 (or Qwen2.5-VL) Live Video Analysis**
         Record a video (from webcam or file), then click **Stop**.
-        Next, click **Analyze** to run Qwen2.5-VL and see textual + chart outputs.
+        Next, click **Analyze** to run the model and see textual + chart outputs.
         """)
 
         with gr.Row():
@@ -139,9 +168,8 @@ def build_app():
                 label="Suggested Recording Duration (seconds)",
                 info="Select how long you plan to record before pressing Stop."
             )
-            # Remove 'source="webcam"' to avoid the TypeError on older Gradio versions
             video = gr.Video(
-                label="Webcam Recording (press the Record button, then Stop)",
+                label="Webcam Recording (press Record, then Stop)",
                 format="mp4"
             )
             analyze_btn = gr.Button("Analyze", variant="primary")
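The output components and event wiring sit below these hunks. A typical hookup for this layout, sketched with hypothetical variable names (only video and analyze_btn appear in the diff):

    text_out = gr.Textbox(label="Model Output")
    chart_out = gr.Plot(label="Analytics Chart")

    analyze_btn.click(
        fn=video_inference,
        inputs=[video, duration],      # 'duration' assumed to name the duration selector above
        outputs=[text_out, chart_out],
    )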
 