AustingDong committed on
Commit 73c356e
1 Parent(s): 6d117d1
Files changed (3)
  1. app.py +86 -212
  2. demo/model_utils.py +29 -12
  3. demo/visualization.py +25 -26
app.py CHANGED
@@ -22,14 +22,15 @@ def set_seed(model_seed = 42):
22
  torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
23
 
24
  set_seed()
25
- clip_utils = Clip_Utils()
26
- clip_utils.init_Clip()
27
  model_utils, vl_gpt, tokenizer = None, None, None
28
- model_name = "Clip"
29
  language_model_max_layer = 24
30
- language_model_best_layer_min = 8
31
- language_model_best_layer_max = 8
32
- vision_model_best_layer = 24
33
 
34
  def clean():
35
  global model_utils, vl_gpt, tokenizer, clip_utils
@@ -71,123 +72,83 @@ def multimodal_understanding(model_type,
71
 
72
  input_text_decoded = ""
73
  answer = ""
74
- if model_name == "Clip":
75
 
76
- inputs = clip_utils.prepare_inputs([question], image)
 
77
 
78
 
79
- if activation_map_method == "GradCAM":
80
- # Generate Grad-CAM
81
- all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
82
 
83
- if visualization_layer_min != visualization_layer_max:
84
- target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
85
- else:
86
- target_layers = [all_layers[visualization_layer_min-1]]
87
- grad_cam = VisualizationClip(clip_utils.model, target_layers)
88
- cam, outputs, grid_size = grad_cam.generate_cam(inputs, target_token_idx=0, visual_method=visual_method)
89
- cam = cam.to("cpu")
90
- cam = [generate_gradcam(cam, image, size=(224, 224))]
91
- grad_cam.remove_hooks()
92
- target_token_decoded = ""
93
-
94
-
95
-
96
- else:
97
-
98
- for param in vl_gpt.parameters():
99
- param.requires_grad = True
100
 
101
 
102
- prepare_inputs = model_utils.prepare_inputs(question, image)
 
103
 
104
- if response_type == "answer + visualization":
105
- if model_name.split('-')[0] == "Janus":
106
- inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
107
- outputs = model_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
108
- else:
109
- outputs = model_utils.generate_outputs(prepare_inputs, temperature, top_p)
110
-
111
- sequences = outputs.sequences.cpu().tolist()
112
- answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
113
- attention_raw = outputs.attentions
114
- print("answer generated")
115
-
116
- input_ids = prepare_inputs.input_ids[0].cpu().tolist()
117
- input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
118
-
119
- if activation_map_method == "GradCAM":
120
- # target_layers = vl_gpt.vision_model.vision_tower.blocks
121
- if focus == "Visual Encoder":
122
- if model_name.split('-')[0] == "Janus":
123
- all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
124
- else:
125
- all_layers = [block.layer_norm1 for block in vl_gpt.vision_tower.vision_model.encoder.layers]
126
- else:
127
- all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
128
-
129
- print("layer values:", visualization_layer_min, visualization_layer_max)
130
- if visualization_layer_min != visualization_layer_max:
131
- print("multi layers")
132
- target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max]
133
- else:
134
- print("single layer")
135
- target_layers = [all_layers[visualization_layer_min-1]]
136
-
137
 
138
- if model_name.split('-')[0] == "Janus":
139
- gradcam = VisualizationJanus(vl_gpt, target_layers)
140
- elif model_name.split('-')[0] == "LLaVA":
141
- gradcam = VisualizationLLaVA(vl_gpt, target_layers)
142
- elif model_name.split('-')[0] == "ChartGemma":
143
- gradcam = VisualizationChartGemma(vl_gpt, target_layers)
144
 
145
- start = 0
146
- cam = []
147
- if focus == "Visual Encoder":
148
- if target_token_idx != -1:
149
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_method, focus)
150
- cam_grid = cam_tensors.reshape(grid_size, grid_size)
151
- cam_i = generate_gradcam(cam_grid, image)
152
- cam_i = add_title_to_image(cam_i, input_ids_decoded[start + target_token_idx])
153
- cam = [cam_i]
154
- else:
155
- i = 0
156
- cam = []
157
- while start + i < len(input_ids_decoded):
158
- if model_name.split('-')[0] == "Janus":
159
- gradcam = VisualizationJanus(vl_gpt, target_layers)
160
- elif model_name.split('-')[0] == "LLaVA":
161
- gradcam = VisualizationLLaVA(vl_gpt, target_layers)
162
- elif model_name.split('-')[0] == "ChartGemma":
163
- gradcam = VisualizationChartGemma(vl_gpt, target_layers)
164
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, i, visual_method, focus, accumulate_method)
165
- cam_grid = cam_tensors.reshape(grid_size, grid_size)
166
- cam_i = generate_gradcam(cam_grid, image)
167
- cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
168
- cam.append(cam_i)
169
- gradcam.remove_hooks()
170
- i += 1
171
  else:
172
- cam_tensors, grid_size, start = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_method, focus, accumulate_method)
173
- if target_token_idx != -1:
174
- input_text_decoded = input_ids_decoded[start + target_token_idx]
175
- for i, cam_tensor in enumerate(cam_tensors):
176
- if i == target_token_idx:
177
- cam_grid = cam_tensor.reshape(grid_size, grid_size)
178
- cam_i = generate_gradcam(cam_grid, image)
179
- cam = [add_title_to_image(cam_i, input_text_decoded)]
180
- break
181
- else:
182
- cam = []
183
- for i, cam_tensor in enumerate(cam_tensors):
184
- cam_grid = cam_tensor.reshape(grid_size, grid_size)
185
- cam_i = generate_gradcam(cam_grid, image)
186
- cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
187
-
188
- cam.append(cam_i)
189
-
190
- gradcam.remove_hooks()
191
 
192
 
193
  # Collect Results
@@ -219,34 +180,7 @@ def model_slider_change(model_type):
219
  global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer_min, language_model_best_layer_max, vision_model_best_layer
220
  model_name = model_type
221
 
222
-
223
- encoder_only_res = [
224
- gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type"),
225
- gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus"),
226
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
227
- gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
228
- ]
229
-
230
- language_res = [
231
- gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
232
- gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
233
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
234
- gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
235
- ]
236
-
237
-
238
- if model_type == "Clip":
239
- clean()
240
- set_seed()
241
- clip_utils = Clip_Utils()
242
- clip_utils.init_Clip()
243
- sliders = [
244
- gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
245
- gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
246
- ]
247
- return tuple(encoder_only_res + sliders)
248
-
249
- elif model_type.split('-')[0] == "Janus":
250
  # best seed: 70
251
  clean()
252
  set_seed()
@@ -263,7 +197,7 @@ def model_slider_change(model_type):
263
  gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_min, step=1, label="visualization layers min"),
264
  gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_max, step=1, label="visualization layers max"),
265
  ]
266
- return tuple(language_res + sliders)
267
 
268
  elif model_type.split('-')[0] == "LLaVA":
269
 
@@ -280,7 +214,7 @@ def model_slider_change(model_type):
280
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
281
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
282
  ]
283
- return tuple(language_res + sliders)
284
 
285
  elif model_type.split('-')[0] == "ChartGemma":
286
  clean()
@@ -290,62 +224,16 @@ def model_slider_change(model_type):
290
  for layer in vl_gpt.language_model.model.layers:
291
  layer.self_attn = ModifiedGemmaAttention(layer.self_attn)
292
  language_model_max_layer = 18
293
- vision_model_best_layer = 19
294
- language_model_best_layer_min = 11
295
  language_model_best_layer_max = 15
296
 
297
  sliders = [
298
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
299
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
300
  ]
301
- return tuple(language_res + sliders)
302
-
303
-
304
 
305
 
306
- def focus_change(focus):
307
- global model_name, language_model_max_layer
308
- if model_name == "Clip":
309
- res = (
310
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
311
- gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
312
- gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
313
- )
314
- return res
315
-
316
- if focus == "Language Model":
317
- if response_type.value == "answer + visualization":
318
- res = (
319
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
320
- gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
321
- gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max")
322
- )
323
- return res
324
- else:
325
- res = (
326
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
327
- gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
328
- gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max")
329
- )
330
- return res
331
-
332
- else:
333
- if model_name.split('-')[0] == "ChartGemma":
334
- res = (
335
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
336
- gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers min"),
337
- gr.Slider(minimum=1, maximum=26, value=vision_model_best_layer, step=1, label="visualization layers max")
338
- )
339
- return res
340
-
341
- else:
342
- res = (
343
- gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type"),
344
- gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
345
- gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
346
- )
347
- return res
348
-
349
 
350
  def test_change(test_selector):
351
  if test_selector == "mini-VLAT":
@@ -376,7 +264,7 @@ with gr.Blocks() as demo:
376
  with gr.Row():
377
 
378
  with gr.Column():
379
- model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
380
  test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old"], value="mini-VLAT", label="test")
381
  question_input = gr.Textbox(label="Input Prompt")
382
  und_seed_input = gr.Number(label="Seed", precision=0, value=42)
@@ -386,15 +274,15 @@ with gr.Blocks() as demo:
386
 
387
 
388
  with gr.Column():
389
- response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
390
- focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
391
- activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="visualization type")
392
  accumulate_method = gr.Dropdown(choices=["sum", "mult"], value="sum", label="layers accumulate method")
393
- visual_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
394
 
395
 
396
- visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
397
- visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
398
 
399
 
400
 
@@ -404,24 +292,10 @@ with gr.Blocks() as demo:
404
  fn=model_slider_change,
405
  inputs=model_selector,
406
  outputs=[
407
- response_type,
408
- focus,
409
- activation_map_method,
410
- visual_method,
411
  visualization_layers_min,
412
  visualization_layers_max
413
  ]
414
  )
415
-
416
- focus.change(
417
- fn = focus_change,
418
- inputs = focus,
419
- outputs=[
420
- activation_map_method,
421
- visualization_layers_min,
422
- visualization_layers_max,
423
- ]
424
- )
425
 
426
 
427
 
 
22
  torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
23
 
24
  set_seed()
 
 
25
  model_utils, vl_gpt, tokenizer = None, None, None
26
+ model_utils = ChartGemma_Utils()
27
+ vl_gpt, tokenizer = model_utils.init_ChartGemma()
28
+ for layer in vl_gpt.language_model.model.layers:
29
+ layer.self_attn = ModifiedGemmaAttention(layer.self_attn)
30
+ model_name = "ChartGemma-3B"
31
  language_model_max_layer = 24
32
+ language_model_best_layer_min = 9
33
+ language_model_best_layer_max = 15
 
34
 
35
  def clean():
36
  global model_utils, vl_gpt, tokenizer, clip_utils
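ModifiedGemmaAttention is defined elsewhere in the repo and is not shown in this diff; this commit only wraps every Gemma self-attention layer with it at startup so the AG-CAM hooks can read attention weights and their gradients. A minimal sketch of a wrapper in that spirit, assuming an HF-style output_attentions kwarg (class and attribute names here are illustrative, not the actual implementation):

import torch.nn as nn

class AttentionTap(nn.Module):
    """Hypothetical stand-in for ModifiedGemmaAttention: delegates to the
    wrapped self-attention module, asks it to return attention weights, and
    keeps slots where forward/backward hooks can stash maps and gradients."""
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.attn_map = None
        self.attn_grad = None

    def save_attn_map(self, attn_map):      # called from a forward hook
        self.attn_map = attn_map

    def save_attn_gradients(self, grad):    # registered on attn_weights
        self.attn_grad = grad

    def forward(self, *args, **kwargs):
        kwargs["output_attentions"] = True  # assumption: HF-style kwarg
        return self.attn(*args, **kwargs)   # (attn_output, attn_weights, ...)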
 
72
 
73
  input_text_decoded = ""
74
  answer = ""
75
+
76
 
77
+ for param in vl_gpt.parameters():
78
+ param.requires_grad = True
79
 
80
 
81
+ prepare_inputs = model_utils.prepare_inputs(question, image)
 
 
82
 
83
+ if response_type == "answer + visualization":
84
+ if model_name.split('-')[0] == "Janus":
85
+ inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
86
+ outputs = model_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
87
+ else:
88
+ outputs = model_utils.generate_outputs(prepare_inputs, temperature, top_p)
 
89
 
90
+ sequences = outputs.sequences.cpu().tolist()
91
+ answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
92
+ attention_raw = outputs.attentions
93
+ print("answer generated")
94
 
95
+ input_ids = prepare_inputs.input_ids[0].cpu().tolist()
96
+ input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
97
 
98
+ if activation_map_method == "AG-CAM":
99
+ # target_layers = vl_gpt.vision_model.vision_tower.blocks
100
+
101
+ all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
102
+
103
+ print("layer values:", visualization_layer_min, visualization_layer_max)
104
+ if visualization_layer_min != visualization_layer_max:
105
+ print("multi layers")
106
+ target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max]
107
+ else:
108
+ print("single layer")
109
+ target_layers = [all_layers[visualization_layer_min-1]]
110
+
 
111
 
112
+ if model_name.split('-')[0] == "Janus":
113
+ gradcam = VisualizationJanus(vl_gpt, target_layers)
114
+ elif model_name.split('-')[0] == "LLaVA":
115
+ gradcam = VisualizationLLaVA(vl_gpt, target_layers)
116
+ elif model_name.split('-')[0] == "ChartGemma":
117
+ gradcam = VisualizationChartGemma(vl_gpt, target_layers)
118
 
119
+ start = 0
120
+ cam = []
121
+
122
+ # utilize the entire sequence, including <image>s, question, and answer
123
+ entire_inputs = prepare_inputs
124
+ if response_type == "answer + visualization" and focus == "question + answer":
125
+ if model_name.split('-')[0] == "Janus" or model_name.split('-')[0] == "LLaVA":
126
+ entire_inputs = model_utils.prepare_inputs(question, image, answer)
 
127
  else:
128
+ entire_inputs["input_ids"] = outputs.sequences
129
+ entire_inputs["attention_mask"] = torch.ones_like(outputs.sequences)
130
+ input_ids = entire_inputs['input_ids'][0].cpu().tolist()
131
+ input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
132
+
133
+ cam_tensors, grid_size, start = gradcam.generate_cam(entire_inputs, tokenizer, temperature, top_p, target_token_idx, visual_method, "Language Model", accumulate_method)
134
+ if target_token_idx != -1:
135
+ input_text_decoded = input_ids_decoded[start + target_token_idx]
136
+ for i, cam_tensor in enumerate(cam_tensors):
137
+ if i == target_token_idx:
138
+ cam_grid = cam_tensor.reshape(grid_size, grid_size)
139
+ cam_i = generate_gradcam(cam_grid, image)
140
+ cam = [add_title_to_image(cam_i, input_text_decoded)]
141
+ break
142
+ else:
143
+ cam = []
144
+ for i, cam_tensor in enumerate(cam_tensors):
145
+ cam_grid = cam_tensor.reshape(grid_size, grid_size)
146
+ cam_i = generate_gradcam(cam_grid, image)
147
+ cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
148
+
149
+ cam.append(cam_i)
150
+
151
+ gradcam.remove_hooks()
152
 
153
 
154
  # Collect Results
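The new "question + answer" focus re-runs the model over the full generated sequence (prompt plus answer) in one differentiable forward pass, so attention gradients exist for answer tokens as well; that is what assigning outputs.sequences to entire_inputs["input_ids"] accomplishes above. A hedged sketch of the idea with a generic Hugging Face causal LM (names are illustrative, not the app's):

import torch

def rescore_full_sequence(model, generated):
    # generated.sequences: [1, seq_len] token ids from model.generate(),
    # i.e. the prompt followed by the sampled answer.
    full_ids = generated.sequences
    full_mask = torch.ones_like(full_ids)
    # One ordinary forward pass with gradients enabled; generate_cam can then
    # backpropagate a per-token scalar through the cached attention maps.
    return model(input_ids=full_ids,
                 attention_mask=full_mask,
                 output_attentions=True)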
 
180
  global model_utils, vl_gpt, tokenizer, clip_utils, model_name, language_model_max_layer, language_model_best_layer_min, language_model_best_layer_max, vision_model_best_layer
181
  model_name = model_type
182
 
183
+ if model_type.split('-')[0] == "Janus":
 
184
  # best seed: 70
185
  clean()
186
  set_seed()
 
197
  gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_min, step=1, label="visualization layers min"),
198
  gr.Slider(minimum=1, maximum=24, value=language_model_best_layer_max, step=1, label="visualization layers max"),
199
  ]
200
+ return tuple(sliders)
201
 
202
  elif model_type.split('-')[0] == "LLaVA":
203
 
 
214
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
215
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
216
  ]
217
+ return tuple(sliders)
218
 
219
  elif model_type.split('-')[0] == "ChartGemma":
220
  clean()
 
224
  for layer in vl_gpt.language_model.model.layers:
225
  layer.self_attn = ModifiedGemmaAttention(layer.self_attn)
226
  language_model_max_layer = 18
227
+ language_model_best_layer_min = 9
 
228
  language_model_best_layer_max = 15
229
 
230
  sliders = [
231
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_min, step=1, label="visualization layers min"),
232
  gr.Slider(minimum=1, maximum=language_model_max_layer, value=language_model_best_layer_max, step=1, label="visualization layers max"),
233
  ]
234
+ return tuple(sliders)
 
 
235
 
236
 
 
 
237
 
238
  def test_change(test_selector):
239
  if test_selector == "mini-VLAT":
 
264
  with gr.Row():
265
 
266
  with gr.Column():
267
+ model_selector = gr.Dropdown(choices=["ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="ChartGemma-3B", label="model")
268
  test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old"], value="mini-VLAT", label="test")
269
  question_input = gr.Textbox(label="Input Prompt")
270
  und_seed_input = gr.Number(label="Seed", precision=0, value=42)
 
274
 
275
 
276
  with gr.Column():
277
+ response_type = gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type")
278
+ focus = gr.Dropdown(choices=["question", "question + answer"], value="question + answer", label="focus")
279
+ activation_map_method = gr.Dropdown(choices=["AG-CAM"], value="AG-CAM", label="visualization type")
280
  accumulate_method = gr.Dropdown(choices=["sum", "mult"], value="sum", label="layers accumulate method")
281
+ visual_method = gr.Dropdown(choices=["softmax", "sigmoid"], value="softmax", label="activation function")
282
 
283
 
284
+ visualization_layers_min = gr.Slider(minimum=1, maximum=18, value=11, step=1, label="visualization layers min")
285
+ visualization_layers_max = gr.Slider(minimum=1, maximum=18, value=15, step=1, label="visualization layers max")
286
 
287
 
288
 
 
292
  fn=model_slider_change,
293
  inputs=model_selector,
294
  outputs=[
 
 
 
 
295
  visualization_layers_min,
296
  visualization_layers_max
297
  ]
298
  )
 
299
 
300
 
301
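With CLIP removed, model_slider_change only needs to update the two layer sliders, and the model dropdown is wired directly to them. A self-contained sketch of this Gradio pattern (layer counts and slider defaults follow the values used in this commit, but the handler body is simplified):

import gradio as gr

def on_model_change(model_type):
    # ChartGemma exposes 18 language-model layers in this demo, the others 24.
    max_layer = 18 if model_type.startswith("ChartGemma") else 24
    return (
        gr.Slider(minimum=1, maximum=max_layer, value=9, step=1,
                  label="visualization layers min"),
        gr.Slider(minimum=1, maximum=max_layer, value=15, step=1,
                  label="visualization layers max"),
    )

with gr.Blocks() as demo:
    model = gr.Dropdown(["ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"],
                        value="ChartGemma-3B", label="model")
    lo = gr.Slider(minimum=1, maximum=18, value=11, step=1, label="visualization layers min")
    hi = gr.Slider(minimum=1, maximum=18, value=15, step=1, label="visualization layers max")
    model.change(fn=on_model_change, inputs=model, outputs=[lo, hi])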
 
demo/model_utils.py CHANGED
@@ -74,14 +74,14 @@ class Janus_Utils(Model_Utils):
74
  return self.vl_gpt, self.tokenizer
75
 
76
  @spaces.GPU(duration=120)
77
- def prepare_inputs(self, question, image):
78
  conversation = [
79
  {
80
  "role": "<|User|>",
81
  "content": f"<image_placeholder>\n{question}",
82
  "images": [image],
83
  },
84
- {"role": "<|Assistant|>", "content": ""},
85
  ]
86
 
87
  pil_images = [Image.fromarray(image)]
@@ -152,16 +152,33 @@ class LLaVA_Utils(Model_Utils):
152
  return self.vl_gpt, self.tokenizer
153
 
154
  @spaces.GPU(duration=120)
155
- def prepare_inputs(self, question, image):
156
- conversation = [
157
- {
158
- "role": "user",
159
- "content": [
160
- {"type": "text", "text": question},
161
- {"type": "image"},
162
- ],
163
- },
164
- ]
 
165
 
166
  prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
167
  pil_images = [Image.fromarray(image).resize((384, 384))]
 
74
  return self.vl_gpt, self.tokenizer
75
 
76
  @spaces.GPU(duration=120)
77
+ def prepare_inputs(self, question, image, answer=None):
78
  conversation = [
79
  {
80
  "role": "<|User|>",
81
  "content": f"<image_placeholder>\n{question}",
82
  "images": [image],
83
  },
84
+ {"role": "<|Assistant|>", "content": answer if answer else ""}
85
  ]
86
 
87
  pil_images = [Image.fromarray(image)]
 
152
  return self.vl_gpt, self.tokenizer
153
 
154
  @spaces.GPU(duration=120)
155
+ def prepare_inputs(self, question, image, answer=None):
156
+ if answer:
157
+ conversation = [
158
+ {
159
+ "role": "user",
160
+ "content": [
161
+ {"type": "text", "text": question},
162
+ {"type": "image"},
163
+ ],
164
+ },
165
+ {
166
+ "role": "assistant",
167
+ "content": [
168
+ {"type": "text", "text": answer},
169
+ ],
170
+ }
171
+ ]
172
+ else:
173
+ conversation = [
174
+ {
175
+ "role": "user",
176
+ "content": [
177
+ {"type": "text", "text": question},
178
+ {"type": "image"},
179
+ ],
180
+ },
181
+ ]
182
 
183
  prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
184
  pil_images = [Image.fromarray(image).resize((384, 384))]
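Both Janus_Utils.prepare_inputs and LLaVA_Utils.prepare_inputs now take an optional answer and append it as an assistant turn, so the chat template can cover the full question + answer sequence. A short usage sketch (question, image and answer are assumed to come from the app):

# Question only, as before:
prepare_inputs = model_utils.prepare_inputs(question, image)

# Question + generated answer, for visualizing attention over answer tokens:
entire_inputs = model_utils.prepare_inputs(question, image, answer)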
demo/visualization.py CHANGED
@@ -25,7 +25,7 @@ class Visualization:
25
  self.hooks.append(layer.register_backward_hook(self._backward_hook))
26
 
27
  def _forward_hook(self, module, input, output):
28
- print("forward_hook: self_attn_input: ", input)
29
  self.activations.append(output)
30
 
31
  def _backward_hook(self, module, grad_in, grad_out):
@@ -42,12 +42,12 @@ class Visualization:
42
  layer.get_attn_map = types.MethodType(get_attn_map, layer)
43
 
44
  def _forward_activate_hooks(self, module, input, output):
45
- print("forward_activate_hool: module: ", module)
46
- print("forward_activate_hook: self_attn_input: ", input)
47
 
48
  attn_output, attn_weights = output # Unpack outputs
49
- print("attn_output shape:", attn_output.shape)
50
- print("attn_weights shape:", attn_weights.shape)
51
  module.save_attn_map(attn_weights)
52
  attn_weights.register_hook(module.save_attn_gradients)
53
 
@@ -137,8 +137,10 @@ class Visualization:
137
 
138
  grad = F.relu(grad)
139
 
 
140
  # cam = grad
141
  cam = act * grad # shape: [1, heads, seq_len, seq_len]
 
142
  cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
143
  cam = cam.to(torch.float32).detach().cpu()
144
  cams.append(cam)
@@ -187,7 +189,6 @@ class Visualization:
187
  # print("cam_sum shape: ", cam_sum.shape)
188
  num_patches = cam_sum.shape[-1] # Last dimension of CAM output
189
  grid_size = int(num_patches ** 0.5)
190
- # print(f"Detected grid size: {grid_size}x{grid_size}")
191
 
192
  cam_sum = cam_sum.view(grid_size, grid_size)
193
  if normalize:
@@ -207,7 +208,6 @@ class Visualization:
207
 
208
  num_patches = cam_l_i.shape[-1] # Last dimension of CAM output
209
  grid_size = int(num_patches ** 0.5)
210
- # print(f"Detected grid size: {grid_size}x{grid_size}")
211
 
212
  # Fix the reshaping step dynamically
213
  cam_reshaped = cam_l_i.view(grid_size, grid_size)
@@ -258,7 +258,6 @@ class VisualizationClip(Visualization):
258
 
259
  @spaces.GPU(duration=120)
260
  def generate_cam(self, input_tensor, target_token_idx=None, visual_method="CLS"):
261
- """ Generates Grad-CAM heatmap for ViT. """
262
  self.setup_grads()
263
  # Forward Backward pass
264
  output_full = self.forward_backward(input_tensor, visual_method, target_token_idx)
@@ -301,9 +300,13 @@ class VisualizationJanus(Visualization):
301
  def forward_backward(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):
302
  # Forward
303
  image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)
304
- input_ids = input_tensor.input_ids
 
305
  start_idx = 620
306
  self.model.zero_grad()
 
 
 
307
  if focus == "Visual Encoder":
308
  loss = outputs.logits.max(dim=-1).values[0, start_idx + target_token_idx]
309
  loss.backward()
@@ -335,11 +338,15 @@ class VisualizationJanus(Visualization):
335
 
336
  elif focus == "Language Model":
337
 
338
- cam_sum = self.grad_cam_llm(mean_inside=True)
339
 
340
- images_seq_mask = input_tensor.images_seq_mask
341
 
342
- cam_sum_lst, grid_size = self.process_multiple(cam_sum, start_idx, images_seq_mask)
343
 
344
  return cam_sum_lst, grid_size, start_idx
345
 
@@ -407,15 +414,6 @@ class VisualizationChartGemma(Visualization):
407
  self._modify_layers()
408
  self._register_hooks_activations()
409
 
410
- # def custom_loss(self, start_idx, input_ids, logits):
411
- # Q = logits.shape[1]
412
- # loss = 0
413
- # q = 0
414
- # while start_idx + q < Q - 1:
415
- # loss += F.cross_entropy(logits[0, start_idx + q], input_ids[0, start_idx + q + 1])
416
- # q += 1
417
- # return loss
418
-
419
 
420
  def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
421
  outputs_raw = self.model(**inputs, output_hidden_states=True)
@@ -429,9 +427,11 @@ class VisualizationChartGemma(Visualization):
429
  elif focus == "Language Model":
430
  self.model.zero_grad()
431
  print("logits shape:", outputs_raw.logits.shape)
 
432
  if target_token_idx == -1:
433
- loss = outputs_raw.logits.max(dim=-1).values.sum()
434
- # loss = self.custom_loss(start_idx, inputs['input_ids'], outputs_raw.logits)
 
435
  else:
436
  loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
437
  loss.backward()
@@ -495,7 +495,7 @@ def generate_gradcam(
495
  normalize=False
496
  ):
497
  """
498
- Generates a Grad-CAM heatmap overlay on top of the input image.
499
 
500
  Parameters:
501
  cam (torch.Tensor): A tensor of shape (C, H, W) representing the
@@ -508,9 +508,8 @@ def generate_gradcam(
508
  normalize (bool): Whether to normalize the heatmap (default False).
509
 
510
  Returns:
511
- PIL.Image: The image overlaid with the Grad-CAM heatmap.
512
  """
513
- # print("Generating Grad-CAM with shape:", cam.shape)
514
 
515
  if normalize:
516
  cam_min, cam_max = cam.min(), cam.max()
 
25
  self.hooks.append(layer.register_backward_hook(self._backward_hook))
26
 
27
  def _forward_hook(self, module, input, output):
28
+ # print("forward_hook: self_attn_input: ", input)
29
  self.activations.append(output)
30
 
31
  def _backward_hook(self, module, grad_in, grad_out):
 
42
  layer.get_attn_map = types.MethodType(get_attn_map, layer)
43
 
44
  def _forward_activate_hooks(self, module, input, output):
45
+ # print("forward_activate_hool: module: ", module)
46
+ # print("forward_activate_hook: self_attn_input: ", input)
47
 
48
  attn_output, attn_weights = output # Unpack outputs
49
+ # print("attn_output shape:", attn_output.shape)
50
+ # print("attn_weights shape:", attn_weights.shape)
51
  module.save_attn_map(attn_weights)
52
  attn_weights.register_hook(module.save_attn_gradients)
53
 
 
137
 
138
  grad = F.relu(grad)
139
 
140
+ # cam = act
141
  # cam = grad
142
  cam = act * grad # shape: [1, heads, seq_len, seq_len]
143
+
144
  cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
145
  cam = cam.to(torch.float32).detach().cpu()
146
  cams.append(cam)
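The per-layer attention-guided CAM step above multiplies each attention map by the ReLU of its gradient and sums over heads. Written in isolation (a sketch with illustrative names, not the class method itself):

import torch
import torch.nn.functional as F

def attn_guided_cam_layer(attn, attn_grad):
    # attn, attn_grad: [1, heads, seq_len, seq_len] from one attention layer
    cam = attn * F.relu(attn_grad)   # keep only positively contributing gradients
    cam = cam.sum(dim=1)             # collapse heads -> [1, seq_len, seq_len]
    return cam.to(torch.float32).detach().cpu()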
 
189
  # print("cam_sum shape: ", cam_sum.shape)
190
  num_patches = cam_sum.shape[-1] # Last dimension of CAM output
191
  grid_size = int(num_patches ** 0.5)
 
192
 
193
  cam_sum = cam_sum.view(grid_size, grid_size)
194
  if normalize:
 
208
 
209
  num_patches = cam_l_i.shape[-1] # Last dimension of CAM output
210
  grid_size = int(num_patches ** 0.5)
 
211
 
212
  # Fix the reshaping step dynamically
213
  cam_reshaped = cam_l_i.view(grid_size, grid_size)
 
258
 
259
  @spaces.GPU(duration=120)
260
  def generate_cam(self, input_tensor, target_token_idx=None, visual_method="CLS"):
 
261
  self.setup_grads()
262
  # Forward Backward pass
263
  output_full = self.forward_backward(input_tensor, visual_method, target_token_idx)
 
300
  def forward_backward(self, input_tensor, tokenizer, temperature, top_p, target_token_idx=None, visual_method="softmax", focus="Visual Encoder"):
301
  # Forward
302
  image_embeddings, inputs_embeddings, outputs = self.model(input_tensor, tokenizer, temperature, top_p)
303
+ print(input_tensor.keys())
304
+ input_ids = input_tensor["input_ids"]
305
  start_idx = 620
306
  self.model.zero_grad()
307
+
308
+
309
+
310
  if focus == "Visual Encoder":
311
  loss = outputs.logits.max(dim=-1).values[0, start_idx + target_token_idx]
312
  loss.backward()
 
338
 
339
  elif focus == "Language Model":
340
 
341
+ # cam_sum = self.grad_cam_llm(mean_inside=True)
342
 
343
+ images_seq_mask = input_tensor.images_seq_mask[0].detach().cpu().tolist()
344
 
345
+ # cam_sum_lst, grid_size = self.process_multiple(cam_sum, start_idx, images_seq_mask)
346
+
347
+ cams = self.attn_guided_cam()
348
+ cam_sum_lst, grid_size = self.process_multiple_acc(cams, start_idx, images_seq_mask, accumulate_method=accumulate_method)
349
+
350
 
351
  return cam_sum_lst, grid_size, start_idx
352
 
 
414
  self._modify_layers()
415
  self._register_hooks_activations()
416
 
 
417
 
418
  def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
419
  outputs_raw = self.model(**inputs, output_hidden_states=True)
 
427
  elif focus == "Language Model":
428
  self.model.zero_grad()
429
  print("logits shape:", outputs_raw.logits.shape)
430
+ print("start_idx:", start_idx)
431
  if target_token_idx == -1:
432
+ logits_prob = F.softmax(outputs_raw.logits, dim=-1)
433
+ loss = logits_prob.max(dim=-1).values.sum()
434
+
435
  else:
436
  loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
437
  loss.backward()
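The change above replaces the all-token backward target (previously a sum of raw max logits) with a sum of max softmax probabilities; the single-token branch still uses the max logit at the chosen answer position. Pulled out as a small helper for clarity (a sketch, not the repo's API):

import torch.nn.functional as F

def cam_backward_target(logits, start_idx, target_token_idx):
    # logits: [1, seq_len, vocab]; the returned scalar is what loss.backward() uses
    if target_token_idx == -1:                       # visualize every token
        probs = F.softmax(logits, dim=-1)
        return probs.max(dim=-1).values.sum()
    # visualize one chosen token after the prompt
    return logits.max(dim=-1).values[0, start_idx + target_token_idx]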
 
495
  normalize=False
496
  ):
497
  """
498
+ Generates a heatmap overlay on top of the input image.
499
 
500
  Parameters:
501
  cam (torch.Tensor): A tensor of shape (C, H, W) representing the
 
508
  normalize (bool): Whether to normalize the heatmap (default False).
509
 
510
  Returns:
511
+ PIL.Image: The image overlaid with the heatmap.
512
  """
 
513
 
514
  if normalize:
515
  cam_min, cam_max = cam.min(), cam.max()
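For reference, the overlay produced by generate_gradcam can be approximated in a few lines of PIL and matplotlib. This is a simplified sketch assuming a 2-D CAM grid and a jet colormap, not the function's actual implementation:

import numpy as np
import torch
from PIL import Image
from matplotlib import cm

def overlay_cam(cam: torch.Tensor, image: np.ndarray, alpha: float = 0.5,
                size: tuple = (384, 384)) -> Image.Image:
    # cam: 2-D tensor (grid_size x grid_size); image: HxWx3 uint8 array
    cam = cam.float()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)    # scale to [0, 1]
    heat = Image.fromarray((cam.numpy() * 255).astype(np.uint8)).resize(size)
    heat_rgb = cm.jet(np.asarray(heat) / 255.0)[..., :3]        # apply colormap
    base = Image.fromarray(image).convert("RGB").resize(size)
    blended = (1 - alpha) * np.asarray(base) / 255.0 + alpha * heat_rgb
    return Image.fromarray((blended * 255).astype(np.uint8))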