AustingDong committed on
Commit
585504b
·
1 Parent(s): 39961f4

Update cam.py

Browse files
Files changed (1) hide show
  1. demo/cam.py +1 -20
demo/cam.py CHANGED
@@ -20,22 +20,18 @@ class AttentionGuidedCAM:
20
  self._register_hooks()
21
 
22
  def _register_hooks(self):
23
- """ Registers hooks to extract activations and gradients from ALL attention layers. """
24
  for layer in self.target_layers:
25
  self.hooks.append(layer.register_forward_hook(self._forward_hook))
26
  self.hooks.append(layer.register_backward_hook(self._backward_hook))
27
 
28
  def _forward_hook(self, module, input, output):
29
- """ Stores attention maps (before softmax) """
30
  self.activations.append(output)
31
 
32
  def _backward_hook(self, module, grad_in, grad_out):
33
- """ Stores gradients """
34
  self.gradients.append(grad_out[0])
35
 
36
 
37
  def remove_hooks(self):
38
- """ Remove hooks after usage. """
39
  for hook in self.hooks:
40
  hook.remove()
41
 
@@ -153,7 +149,6 @@ class AttentionGuidedCAMJanus(AttentionGuidedCAM):
153
 
154
  @spaces.GPU(duration=120)
155
  def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
156
- """ Generates Grad-CAM heatmap for ViT. """
157
 
158
 
159
  # Forward pass
@@ -338,7 +333,6 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
338
 
339
  @spaces.GPU(duration=120)
340
  def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
341
- """ Generates Grad-CAM heatmap for ViT. """
342
 
343
  # Forward pass
344
  outputs_raw = self.model(**inputs)
@@ -401,16 +395,12 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
401
  start_idx = last + 1
402
  for i in range(start_idx, cam_sum_raw.shape[1]):
403
  cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
404
- # cam_sum_min = cam_sum.min()
405
- # cam_sum_max = cam_sum.max()
406
- # cam_sum = (cam_sum - cam_sum_min) / (cam_sum_max - cam_sum_min)
407
  cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, 1024]
408
  print("cam_sum shape: ", cam_sum.shape)
409
  num_patches = cam_sum.shape[-1] # Last dimension of CAM output
410
  grid_size = int(num_patches ** 0.5)
411
  print(f"Detected grid size: {grid_size}x{grid_size}")
412
-
413
- # Fix the reshaping step dynamically
414
 
415
  cam_sum = cam_sum.view(grid_size, grid_size)
416
  cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
@@ -468,7 +458,6 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
468
 
469
  @spaces.GPU(duration=120)
470
  def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
471
- """ Generates Grad-CAM heatmap for ViT. """
472
 
473
  # Forward pass
474
  outputs_raw = self.model(**inputs)
@@ -545,14 +534,6 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
545
  cam_sum = F.relu(cam_sum)
546
  cam_sum = cam_sum.to(torch.float32)
547
 
548
- # thresholding
549
- # percentile = torch.quantile(cam_sum, 0.4) # Adjust threshold dynamically
550
- # cam_sum[cam_sum < percentile] = 0
551
-
552
- # Reshape
553
- # if visual_pooling_method == "CLS":
554
- # cam_sum = cam_sum[0, 1:]
555
-
556
  # cam_sum shape: [1, seq_len, seq_len]
557
  cam_sum_lst = []
558
  cam_sum_raw = cam_sum
 
20
  self._register_hooks()
21
 
22
  def _register_hooks(self):
 
23
  for layer in self.target_layers:
24
  self.hooks.append(layer.register_forward_hook(self._forward_hook))
25
  self.hooks.append(layer.register_backward_hook(self._backward_hook))
26
 
27
  def _forward_hook(self, module, input, output):
 
28
  self.activations.append(output)
29
 
30
  def _backward_hook(self, module, grad_in, grad_out):
 
31
  self.gradients.append(grad_out[0])
32
 
33
 
34
  def remove_hooks(self):
 
35
  for hook in self.hooks:
36
  hook.remove()
37
 
 
149
 
150
  @spaces.GPU(duration=120)
151
  def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
 
152
 
153
 
154
  # Forward pass
 
333
 
334
  @spaces.GPU(duration=120)
335
  def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
 
336
 
337
  # Forward pass
338
  outputs_raw = self.model(**inputs)
 
395
  start_idx = last + 1
396
  for i in range(start_idx, cam_sum_raw.shape[1]):
397
  cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
398
+
 
 
399
  cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, 1024]
400
  print("cam_sum shape: ", cam_sum.shape)
401
  num_patches = cam_sum.shape[-1] # Last dimension of CAM output
402
  grid_size = int(num_patches ** 0.5)
403
  print(f"Detected grid size: {grid_size}x{grid_size}")
 
 
404
 
405
  cam_sum = cam_sum.view(grid_size, grid_size)
406
  cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
 
458
 
459
  @spaces.GPU(duration=120)
460
  def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
 
461
 
462
  # Forward pass
463
  outputs_raw = self.model(**inputs)
 
534
  cam_sum = F.relu(cam_sum)
535
  cam_sum = cam_sum.to(torch.float32)
536
 
 
 
 
 
 
 
 
 
537
  # cam_sum shape: [1, seq_len, seq_len]
538
  cam_sum_lst = []
539
  cam_sum_raw = cam_sum