AustingDong committed
Commit 4db7aa5 · 1 Parent(s): 5da9d34

fixed multi-layer

Files changed (3):
  1. .gitignore +3 -1
  2. app.py +52 -19
  3. demo/cam.py +35 -96
.gitignore CHANGED
@@ -418,4 +418,6 @@ tags
[._]*.un~
.vscode
.github
- generated_samples/
+ generated_samples/
+
+ results
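The new `results` ignore entry matches the per-run output directory that app.py begins writing in this commit. A minimal sketch of the layout, mirroring the `FILES_ROOT` f-string from the app.py hunk below with illustrative, made-up values:

# Illustrative only: reproduces the FILES_ROOT path pattern added in app.py; the values are invented.
import os

RESULTS_ROOT = "./results"
model_name, focus, chart_type = "ChartGemma-3B", "Language Model", "BarChart"
visualization_layer_min, visualization_layer_max = 15, 18

FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}"
os.makedirs(FILES_ROOT, exist_ok=True)
print(FILES_ROOT)  # ./results/ChartGemma-3B/Language Model/BarChart/layer15-18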
app.py CHANGED
@@ -9,6 +9,7 @@ from demo.model_utils import Clip_Utils, Janus_Utils, LLaVA_Utils, ChartGemma_Ut
import numpy as np
import matplotlib.pyplot as plt
import gc
+ import os
import spaces
from PIL import Image

@@ -53,7 +54,7 @@ def multimodal_understanding(model_type,
activation_map_method,
visual_pooling_method,
image, question, seed, top_p, temperature, target_token_idx,
- visualization_layer_min, visualization_layer_max, focus, response_type):
+ visualization_layer_min, visualization_layer_max, focus, response_type, chart_type):
# Clear CUDA cache before generating
gc.collect()
if torch.cuda.is_available():
@@ -75,7 +76,8 @@ def multimodal_understanding(model_type,
if activation_map_method == "GradCAM":
# Generate Grad-CAM
all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
- if visualization_layers_min.value != visualization_layers_max.value:
+
+ if visualization_layer_min != visualization_layer_max:
target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
else:
target_layers = [all_layers[visualization_layer_min-1]]
@@ -110,12 +112,6 @@ def multimodal_understanding(model_type,

input_ids = prepare_inputs.input_ids[0].cpu().tolist()
input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
- # if model_name.split('-')[0] == "Janus":
- # start = 620
- # elif model_name.split('-')[0] == "ChartGemma":
- # start = 1024
- # elif model_name.split('-')[0] == "LLaVA":
- # start = 581

if activation_map_method == "GradCAM":
# target_layers = vl_gpt.vision_model.vision_tower.blocks
@@ -123,11 +119,15 @@
all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
else:
all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
-
- if visualization_layers_min.value != visualization_layers_max.value:
- target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
+
+ print("layer values:", visualization_layer_min, visualization_layer_max)
+ if visualization_layer_min != visualization_layer_max:
+ print("multi layers")
+ target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max]
else:
+ print("single layer")
target_layers = [all_layers[visualization_layer_min-1]]
+

if model_name.split('-')[0] == "Janus":
gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
@@ -165,6 +165,26 @@ def multimodal_understanding(model_type,

cam.append(cam_i)

+ # Collect Results
+ RESULTS_ROOT = "./results"
+ FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}"
+ os.makedirs(FILES_ROOT, exist_ok=True)
+ if focus == "Visual Encoder":
+ cam[0].save(f"{FILES_ROOT}/{visual_pooling_method}.png")
+ else:
+ for i, cam_p in enumerate(cam):
+ cam_p.save(f"{FILES_ROOT}/{i}.png")
+
+ with open(f"{FILES_ROOT}/input_text_decoded.txt", "w") as f:
+ f.write(input_text_decoded)
+ f.close()
+
+ with open(f"{FILES_ROOT}/answer.txt", "w") as f:
+ f.write(answer)
+ f.close()
+
+
+
return answer, cam, input_text_decoded


@@ -235,8 +255,8 @@ def model_slider_change(model_type):

res = (
gr.Dropdown(choices=["Visualization only", "answer + visualization"], value="answer + visualization", label="response_type"),
- gr.Slider(minimum=1, maximum=18, value=15, step=1, label="visualization layers min"),
- gr.Slider(minimum=1, maximum=18, value=15, step=1, label="visualization layers max"),
+ gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers min"),
+ gr.Slider(minimum=1, maximum=language_model_best_layer, value=language_model_best_layer, step=1, label="visualization layers max"),
gr.Dropdown(choices=["Language Model"], value="Language Model", label="focus"),
gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
)
@@ -291,7 +311,7 @@ with gr.Blocks() as demo:
activation_map_output = gr.Gallery(label="activation Map", height=300, columns=1)

with gr.Column():
- model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-1B", "Janus-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
+ model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
response_type = gr.Dropdown(choices=["Visualization only"], value="Visualization only", label="response_type")
focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
activation_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="activation map type")
@@ -339,7 +359,8 @@ with gr.Blocks() as demo:



- understanding_button = gr.Button("Chat")
+ understanding_button = gr.Button("Submit")
+ chart_type = gr.Textbox(label="Chart Type")
understanding_output = gr.Textbox(label="Answer")
understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

@@ -349,67 +370,79 @@ with gr.Blocks() as demo:
examples=[

[
+ "LineChart",
"What was the price of a barrel of oil in February 2020?",
"images/LineChart.png"
],

[
+ "BarChart",
"What is the average internet speed in Japan?",
"images/BarChart.png"
],

[
+ "StackedBar",
"What is the cost of peanuts in Seoul?",
"images/StackedBar.png"
],

- [
+ [
+ "100%StackedBar",
"Which country has the lowest proportion of Gold medals?",
"images/Stacked100.png"
],

[
+ "PieChart",
"What is the approximate global smartphone market share of Samsung?",
"images/PieChart.png"
],

- [
+ [
+ "Histogram",
"What distance have customers traveled in the taxi the most?",
"images/Histogram.png"
],

[
+ "Scatterplot",
"True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
"images/Scatterplot.png"
],

[
+ "AreaChart",
"What was the average price of pount of coffee beans in October 2019?",
"images/AreaChart.png"
],

[
+ "StackedArea",
"What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
"images/StackedArea.png"
],

[
+ "BubbleChart",
"Which city's metro system has the largest number of stations?",
"images/BubbleChart.png"
],

[
+ "Choropleth",
"True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
"images/Choropleth_New.png"
],

[
+ "TreeMap",
"True/False: eBay is nested in the Software category.",
"images/TreeMap.png"
]

],
- inputs=[question_input, image_input],
+ inputs=[chart_type, question_input, image_input],
)


@@ -418,7 +451,7 @@ with gr.Blocks() as demo:
understanding_button.click(
multimodal_understanding,
inputs=[model_selector, activation_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
- visualization_layers_min, visualization_layers_max, focus, response_type],
+ visualization_layers_min, visualization_layers_max, focus, response_type, chart_type],
outputs=[understanding_output, activation_map_output, understanding_target_token_decoded_output]
)
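The heart of the "fixed multi-layer" change in app.py is the layer-range slice used for the vl_gpt models: `all_layers[visualization_layer_min-1 : visualization_layer_max-1]` silently dropped the last layer of the selected range, and the hunk above changes the upper bound to `visualization_layer_max` so the range is inclusive (the CLIP branch keeps the previous bound); the comparisons also now use the passed-in slider values instead of the Gradio components' `.value`. A minimal sketch of the corrected selection logic; `select_target_layers` is a hypothetical helper, in app.py the same logic is written inline:

# Hypothetical helper illustrating the corrected, inclusive layer-range slice.
def select_target_layers(all_layers, layer_min, layer_max):
    """Return layers layer_min..layer_max (1-indexed, inclusive)."""
    if layer_min != layer_max:
        # upper bound is layer_max (not layer_max - 1), so the last selected layer is kept
        return all_layers[layer_min - 1 : layer_max]
    return [all_layers[layer_min - 1]]

layers = [f"layer_{i}" for i in range(1, 7)]
print(select_target_layers(layers, 2, 4))  # ['layer_2', 'layer_3', 'layer_4']
print(select_target_layers(layers, 3, 3))  # ['layer_3']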
 
demo/cam.py CHANGED
@@ -247,14 +247,10 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):

act = act.mean(dim=1)

-
# Compute mean of gradients
print("grad_shape:", grad.shape)
grad_weights = F.relu(grad.mean(dim=1))

-
- # cam, _ = (act * grad_weights).max(dim=-1)
- # cam = act * grad_weights
cam = act * grad_weights
print(cam.shape)

@@ -266,17 +262,12 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):

# Normalize
cam_sum = F.relu(cam_sum)
- # cam_sum = cam_sum - cam_sum.min()
- # cam_sum = cam_sum / cam_sum.max()

# thresholding
cam_sum = cam_sum.to(torch.float32)
percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
cam_sum[cam_sum < percentile] = 0

- # Reshape
- # if visual_pooling_method == "CLS":
- # cam_sum = cam_sum[0, 1:]

# cam_sum shape: [1, seq_len, seq_len]
cam_sum_lst = []
@@ -300,15 +291,6 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):

return cam_sum_lst, grid_size, start

- # Aggregate activations and gradients from ALL layers
-
-
-
-
-
-
-
-


class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
@@ -376,7 +358,6 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
cam_sum = None

- # Ver 2
for act, grad in zip(self.activations, self.gradients):

print("act shape", act.shape)
@@ -397,13 +378,6 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
cam_sum = F.relu(cam_sum)
cam_sum = cam_sum.to(torch.float32)

- # thresholding
- # percentile = torch.quantile(cam_sum, 0.4) # Adjust threshold dynamically
- # cam_sum[cam_sum < percentile] = 0
-
- # Reshape
- # if visual_pooling_method == "CLS":
- # cam_sum = cam_sum[0, 1:]

# cam_sum shape: [1, seq_len, seq_len]
cam_sum_lst = []
@@ -412,7 +386,7 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
for i in range(start_idx, cam_sum_raw.shape[1]):
cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]

- cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, 1024]
+ cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, img_seq_len]
print("cam_sum shape: ", cam_sum.shape)
num_patches = cam_sum.shape[-1] # Last dimension of CAM output
grid_size = int(num_patches ** 0.5)
@@ -430,19 +404,6 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):



-
-
-
-
-
-
-
-
-
-
-
-
-
class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
def __init__(self, model, target_layers):
self.target_layers = target_layers
@@ -489,7 +450,6 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
self.model.zero_grad()
# print(outputs_raw)
loss = outputs_raw.logits.max(dim=-1).values.sum()
-
loss.backward()

# get image masks
@@ -507,75 +467,54 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
# Aggregate activations and gradients from ALL layers
self.activations = [layer.get_attn_map() for layer in self.target_layers]
self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
- cam_sum = None
- # Ver 1
- # for act, grad in zip(self.activations, self.gradients):
- # # act = torch.sigmoid(act)
- # print("act:", act)
- # print(len(act))
- # print("act_shape:", act.shape)
- # # print("act1_shape:", act[1].shape)
-
- # act = F.relu(act.mean(dim=1))
-
-
- # # Compute mean of gradients
- # print("grad:", grad)
- # print(len(grad))
- # print("grad_shape:", grad.shape)
- # grad_weights = grad.mean(dim=1)
-
- # print("act shape", act.shape)
- # print("grad_weights shape", grad_weights.shape)
-
- # cam = act * grad_weights
- # # cam = act
- # print(cam.shape)
-
- # # Sum across all layers
- # if cam_sum is None:
- # cam_sum = cam
- # else:
- # cam_sum += cam
+ print(f"layers shape: {len(self.target_layers)}")
+ print("activations & gradients shape", len(self.activations), len(self.gradients))

+ cams = []
+
# Ver 2
for act, grad in zip(self.activations, self.gradients):

print("act shape", act.shape)
print("grad shape", grad.shape)
-
+
grad = F.relu(grad)

cam = act * grad # shape: [1, heads, seq_len, seq_len]
cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
+ cam = cam.to(torch.float32).detach().cpu()
+ cams.append(cam)

- # Sum across all layers
- if cam_sum is None:
- cam_sum = cam
- else:
- cam_sum += cam
+ # cam_sum = F.relu(cam_sum)
+ # cam_sum = cam_sum.to(torch.float32)

- cam_sum = F.relu(cam_sum)
- cam_sum = cam_sum.to(torch.float32)
-
- # cam_sum shape: [1, seq_len, seq_len]
+ # cams shape: [layers, 1, seq_len, seq_len]
cam_sum_lst = []
- cam_sum_raw = cam_sum
+
start_idx = last + 1
- for i in range(start_idx, cam_sum_raw.shape[1]):
- cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
- # cam_sum_min = cam_sum.min()
- # cam_sum_max = cam_sum.max()
- # cam_sum = (cam_sum - cam_sum_min) / (cam_sum_max - cam_sum_min)
- cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, 1024]
- print("cam_sum shape: ", cam_sum.shape)
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
- grid_size = int(num_patches ** 0.5)
- print(f"Detected grid size: {grid_size}x{grid_size}")
+ for i in range(start_idx, cams[0].shape[1]):
+ cam_sum = None
+ for layer, cam_l in enumerate(cams):
+ cam_l_i = cam_l[0, i, :] # shape: [1: seq_len]

- # Fix the reshaping step dynamically
-
- cam_sum = cam_sum.view(grid_size, grid_size)
+ cam_l_i = cam_l_i[image_mask].unsqueeze(0) # shape: [1, img_seq_len]
+ # print(f"layer: {layer}, token index: {i}")
+ # print("cam_sum shape: ", cam_l_i.shape)
+ num_patches = cam_l_i.shape[-1] # Last dimension of CAM output
+ grid_size = int(num_patches ** 0.5)
+ # print(f"Detected grid size: {grid_size}x{grid_size}")
+
+ # Fix the reshaping step dynamically
+ cam_reshaped = cam_l_i.view(grid_size, grid_size)
+ # print(f"max: {cam_reshaped.max()}, min: {cam_reshaped.min()}")
+ cam_normalized = (cam_reshaped - cam_reshaped.min()) / (cam_reshaped.max() - cam_reshaped.min())
+ if cam_sum == None:
+ cam_sum = cam_normalized
+ else:
+ cam_sum += cam_normalized
+ # print(f"normalized: max: {cam_normalized.max()}, min: {cam_normalized.min()}")
+
+ # print(f"sum: max: {cam_sum.max()}, min: {cam_sum.min()}")
cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
cam_sum_lst.append(cam_sum)

@@ -604,7 +543,7 @@ def generate_gradcam(
alpha=0.5,
colormap=cv2.COLORMAP_JET,
aggregation='mean',
- normalize=True
+ normalize=False
):
"""
Generates a Grad-CAM heatmap overlay on top of the input image.
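In demo/cam.py, AttentionGuidedCAMChartGemma no longer sums raw per-layer CAMs into a single `cam_sum`; it collects one CAM per target layer in `cams`, and for each generated token it normalizes every layer's map to [0, 1] before summing and renormalizing, which puts the layers on a comparable scale before aggregation. A minimal standalone sketch of that aggregation, assuming random tensors stand in for the real attention maps and gradients and using a hypothetical `aggregate_multilayer_cam` helper:

# Sketch only: mirrors the per-layer normalize-then-sum loop that
# AttentionGuidedCAMChartGemma uses after this commit, on toy data.
import torch

def aggregate_multilayer_cam(cams, token_idx, image_mask):
    """Normalize each layer's CAM for one token to [0, 1], sum across layers, renormalize."""
    cam_sum = None
    for cam_l in cams:                                                # cam_l: [1, seq_len, seq_len]
        cam_l_i = cam_l[0, token_idx, :][image_mask].unsqueeze(0)     # [1, img_seq_len]
        grid_size = int(cam_l_i.shape[-1] ** 0.5)
        cam_l_i = cam_l_i.view(grid_size, grid_size)
        cam_l_i = (cam_l_i - cam_l_i.min()) / (cam_l_i.max() - cam_l_i.min())
        cam_sum = cam_l_i if cam_sum is None else cam_sum + cam_l_i
    return (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())

# Toy example: 3 layers, a 24-token sequence whose first 16 tokens are image patches.
seq_len, n_img = 24, 16
cams = [torch.rand(1, seq_len, seq_len) for _ in range(3)]
image_mask = torch.zeros(seq_len, dtype=torch.bool)
image_mask[:n_img] = True
print(aggregate_multilayer_cam(cams, token_idx=20, image_mask=image_mask).shape)  # torch.Size([4, 4])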