AustingDong committed on
Commit
b9b9d9b
·
1 Parent(s): a907ad0

modified llava

Browse files
Files changed (2) hide show
  1. demo/cam.py +70 -59
  2. demo/model_utils.py +2 -1
demo/cam.py CHANGED
@@ -335,80 +335,91 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
335
  self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
336
 
337
  @spaces.GPU(duration=120)
338
- def generate_cam(self, input_tensor, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
339
  """ Generates Grad-CAM heatmap for ViT. """
340
 
341
-
342
  # Forward pass
343
- outputs_raw = self.model(**input_tensor)
344
 
345
- if focus == "Language Model":
346
- loss = self.target_layers[-1].attention_map.sum()
347
- self.model.zero_grad()
348
- loss.backward()
349
-
350
- self.activations = [layer.get_attn_map() for layer in self.target_layers]
351
- self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
352
 
353
- cam_sum = None
354
- for act, grad in zip(self.activations, self.gradients):
355
- # act = torch.sigmoid(act)
356
- print("act_shape:", act.shape)
357
-
358
- act = F.relu(act.mean(dim=1))
359
-
 
 
 
360
 
361
- # Compute mean of gradients
362
- print("grad_shape:", grad.shape)
363
- grad_weights = grad.mean(dim=1)
364
 
 
 
 
 
365
 
366
- # cam, _ = (act * grad_weights).max(dim=-1)
367
- # cam = act * grad_weights
368
- cam = act * grad_weights
369
- print(cam.shape)
370
 
371
- # Sum across all layers
372
- if cam_sum is None:
373
- cam_sum = cam
374
- else:
375
- cam_sum += cam
376
 
377
- # Normalize
378
- cam_sum = F.relu(cam_sum)
379
- # cam_sum = cam_sum - cam_sum.min()
380
- # cam_sum = cam_sum / cam_sum.max()
381
 
382
- # thresholding
383
- cam_sum = cam_sum.to(torch.float32)
384
- percentile = torch.quantile(cam_sum, 0.2) # Adjust threshold dynamically
385
- cam_sum[cam_sum < percentile] = 0
386
 
387
- # Reshape
388
- # if visual_pooling_method == "CLS":
389
- # cam_sum = cam_sum[0, 1:]
390
 
391
- # cam_sum shape: [1, seq_len, seq_len]
392
- cam_sum_lst = []
393
- cam_sum_raw = cam_sum
394
- grid_size = 32
395
- for i in range(512, cam_sum_raw.shape[1]):
396
- cam_sum = cam_sum_raw[:, i, :] # shape: [1: seq_len]
397
- cam_sum = cam_sum[input_tensor.images_seq_mask].unsqueeze(0) # shape: [1, 576]
398
- print("cam_sum shape: ", cam_sum.shape)
399
- num_patches = cam_sum.shape[-1] # Last dimension of CAM output
400
- grid_size = int(num_patches ** 0.5)
401
- print(f"Detected grid size: {grid_size}x{grid_size}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
- # Fix the reshaping step dynamically
404
-
405
- cam_sum = cam_sum.view(grid_size, grid_size)
406
- cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
407
- cam_sum = cam_sum.detach().to("cpu")
408
- cam_sum_lst.append(cam_sum)
409
 
410
 
411
- return cam_sum_lst, grid_size
412
 
413
 
414
 
@@ -546,7 +557,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
546
  # cam_sum shape: [1, seq_len, seq_len]
547
  cam_sum_lst = []
548
  cam_sum_raw = cam_sum
549
- start_idx = 1024
550
  for i in range(start_idx, cam_sum_raw.shape[1]):
551
  cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
552
  # cam_sum_min = cam_sum.min()
 
335
  self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
336
 
337
  @spaces.GPU(duration=120)
338
+ def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
339
  """ Generates Grad-CAM heatmap for ViT. """
340
 
 
341
  # Forward pass
342
+ outputs_raw = self.model(**inputs)
343
 
344
+ self.model.zero_grad()
345
+ print(outputs_raw)
346
+ # loss = self.target_layers[-1].attention_map.sum()
347
+ loss = outputs_raw.logits.max(dim=-1).values.sum()
348
+ loss.backward()
 
 
349
 
350
+ # get image masks
351
+ image_mask = []
352
+ last = 0
353
+ for i in range(inputs["input_ids"].shape[1]):
354
+ decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
355
+ if (decoded_token == "<image>"):
356
+ image_mask.append(True)
357
+ last = i
358
+ else:
359
+ image_mask.append(False)
360
 
 
 
 
361
 
362
+ # Aggregate activations and gradients from ALL layers
363
+ self.activations = [layer.get_attn_map() for layer in self.target_layers]
364
+ self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
365
+ cam_sum = None
366
 
367
+ # Ver 2
368
+ for act, grad in zip(self.activations, self.gradients):
 
 
369
 
370
+ print("act shape", act.shape)
371
+ print("grad shape", grad.shape)
 
 
 
372
 
373
+ act = F.relu(act)
374
+ grad = F.relu(grad)
 
 
375
 
 
 
 
 
376
 
377
+ cam = act * grad # shape: [1, heads, seq_len, seq_len]
378
+ cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
 
379
 
380
+ # Sum across all layers
381
+ if cam_sum is None:
382
+ cam_sum = cam
383
+ else:
384
+ cam_sum += cam
385
+
386
+ cam_sum = F.relu(cam_sum)
387
+ cam_sum = cam_sum.to(torch.float32)
388
+
389
+ # thresholding
390
+ # percentile = torch.quantile(cam_sum, 0.4) # Adjust threshold dynamically
391
+ # cam_sum[cam_sum < percentile] = 0
392
+
393
+ # Reshape
394
+ # if visual_pooling_method == "CLS":
395
+ # cam_sum = cam_sum[0, 1:]
396
+
397
+ # cam_sum shape: [1, seq_len, seq_len]
398
+ cam_sum_lst = []
399
+ cam_sum_raw = cam_sum
400
+ start_idx = last + 1
401
+ for i in range(start_idx, cam_sum_raw.shape[1]):
402
+ cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
403
+ # cam_sum_min = cam_sum.min()
404
+ # cam_sum_max = cam_sum.max()
405
+ # cam_sum = (cam_sum - cam_sum_min) / (cam_sum_max - cam_sum_min)
406
+ cam_sum = cam_sum[image_mask].unsqueeze(0) # shape: [1, 1024]
407
+ print("cam_sum shape: ", cam_sum.shape)
408
+ num_patches = cam_sum.shape[-1] # Last dimension of CAM output
409
+ grid_size = int(num_patches ** 0.5)
410
+ print(f"Detected grid size: {grid_size}x{grid_size}")
411
+
412
+ # Fix the reshaping step dynamically
413
+
414
+ cam_sum = cam_sum.view(grid_size, grid_size)
415
+ cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
416
+ cam_sum_lst.append(cam_sum)
417
+
418
+
419
+ return cam_sum_lst, grid_size
420
 
 
 
 
 
 
 
421
 
422
 
 
423
 
424
 
425
 
 
557
  # cam_sum shape: [1, seq_len, seq_len]
558
  cam_sum_lst = []
559
  cam_sum_raw = cam_sum
560
+ start_idx = last + 1
561
  for i in range(start_idx, cam_sum_raw.shape[1]):
562
  cam_sum = cam_sum_raw[0, i, :] # shape: [1: seq_len]
563
  # cam_sum_min = cam_sum.min()
demo/model_utils.py CHANGED
@@ -119,7 +119,8 @@ class LLaVA_Utils(Model_Utils):
119
 
120
  def init_LLaVA(self):
121
 
122
- model_path = f"llava-hf/llava-1.5-7b-hf"
 
123
  config = AutoConfig.from_pretrained(model_path)
124
 
125
  self.vl_gpt = LlavaForConditionalGeneration.from_pretrained(model_path,
 
119
 
120
  def init_LLaVA(self):
121
 
122
+ # model_path = "llava-hf/llava-1.5-7b-hf"
123
+ model_path = "llava-hf/llava-v1.6-mistral-7b-hf"
124
  config = AutoConfig.from_pretrained(model_path)
125
 
126
  self.vl_gpt = LlavaForConditionalGeneration.from_pretrained(model_path,