AustingDong committed · Commit b9b9d9b · 1 Parent(s): a907ad0

modified llava

Browse files:
- demo/cam.py          +70 -59
- demo/model_utils.py  +2 -1
demo/cam.py  CHANGED

@@ -335,80 +335,91 @@ class AttentionGuidedCAMLLaVA(AttentionGuidedCAM):
         self.hooks.append(layer.register_forward_hook(self._forward_activate_hooks))
 
     @spaces.GPU(duration=120)
-    def generate_cam(self,
+    def generate_cam(self, inputs, tokenizer, temperature, top_p, class_idx=None, visual_pooling_method="CLS", focus="Visual Encoder"):
         """ Generates Grad-CAM heatmap for ViT. """
 
         # Forward pass
-        outputs_raw = self.model(**
+        outputs_raw = self.model(**inputs)
 
+        self.model.zero_grad()
+        print(outputs_raw)
+        # loss = self.target_layers[-1].attention_map.sum()
+        loss = outputs_raw.logits.max(dim=-1).values.sum()
+        loss.backward()
 
+        # get image masks
+        image_mask = []
+        last = 0
+        for i in range(inputs["input_ids"].shape[1]):
+            decoded_token = tokenizer.decode(inputs["input_ids"][0][i].item())
+            if (decoded_token == "<image>"):
+                image_mask.append(True)
+                last = i
+            else:
+                image_mask.append(False)
 
-        self.activations = [layer.get_attn_map() for layer in self.target_layers]
-        self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
+        # Aggregate activations and gradients from ALL layers
+        self.activations = [layer.get_attn_map() for layer in self.target_layers]
+        self.gradients = [layer.get_attn_gradients() for layer in self.target_layers]
+        cam_sum = None
 
-        # Compute mean of gradients
-        print("grad_shape:", grad.shape)
-        grad_weights = grad.mean(dim=1)
+        # Ver 2
+        for act, grad in zip(self.activations, self.gradients):
+            print("act shape", act.shape)
+            print("grad shape", grad.shape)
 
+            act = F.relu(act)
+            grad = F.relu(grad)
 
-        cam = act * grad_weights
-        print(cam.shape)
+            cam = act * grad  # shape: [1, heads, seq_len, seq_len]
+            cam = cam.sum(dim=1)  # shape: [1, seq_len, seq_len]
 
-            cam_sum = cam
-        else:
-            cam_sum += cam
+            # Sum across all layers
+            if cam_sum is None:
+                cam_sum = cam
+            else:
+                cam_sum += cam
 
-        # cam_sum = cam_sum - cam_sum.min()
-        # cam_sum = cam_sum / cam_sum.max()
-
-        # thresholding
-        cam_sum = cam_sum.to(torch.float32)
-        percentile = torch.quantile(cam_sum, 0.2)  # Adjust threshold dynamically
-        cam_sum[cam_sum < percentile] = 0
-
-        # cam_sum = cam_sum[0, 1:]
+        cam_sum = F.relu(cam_sum)
+        cam_sum = cam_sum.to(torch.float32)
+
+        # thresholding
+        # percentile = torch.quantile(cam_sum, 0.4)  # Adjust threshold dynamically
+        # cam_sum[cam_sum < percentile] = 0
+
+        # Reshape
+        # if visual_pooling_method == "CLS":
+        #     cam_sum = cam_sum[0, 1:]
+
+        # cam_sum shape: [1, seq_len, seq_len]
+        cam_sum_lst = []
+        cam_sum_raw = cam_sum
+        start_idx = last + 1
+        for i in range(start_idx, cam_sum_raw.shape[1]):
+            cam_sum = cam_sum_raw[0, i, :]  # shape: [1: seq_len]
+            # cam_sum_min = cam_sum.min()
+            # cam_sum_max = cam_sum.max()
+            # cam_sum = (cam_sum - cam_sum_min) / (cam_sum_max - cam_sum_min)
+            cam_sum = cam_sum[image_mask].unsqueeze(0)  # shape: [1, 1024]
+            print("cam_sum shape: ", cam_sum.shape)
+            num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
+            grid_size = int(num_patches ** 0.5)
+            print(f"Detected grid size: {grid_size}x{grid_size}")
 
-        # Fix the reshaping step dynamically
-        cam_sum = cam_sum.view(grid_size, grid_size)
-        cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
-        cam_sum = cam_sum.detach().to("cpu")
-        cam_sum_lst.append(cam_sum)
+            # Fix the reshaping step dynamically
+            cam_sum = cam_sum.view(grid_size, grid_size)
+            cam_sum = (cam_sum - cam_sum.min()) / (cam_sum.max() - cam_sum.min())
+            cam_sum_lst.append(cam_sum)
 
         return cam_sum_lst, grid_size
 
@@ -546,7 +557,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
         # cam_sum shape: [1, seq_len, seq_len]
         cam_sum_lst = []
         cam_sum_raw = cam_sum
-        start_idx =
+        start_idx = last + 1
         for i in range(start_idx, cam_sum_raw.shape[1]):
             cam_sum = cam_sum_raw[0, i, :]  # shape: [1: seq_len]
             # cam_sum_min = cam_sum.min()
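Note on the cam.py change, for readers skimming the diff: the rewritten generate_cam backpropagates the sum of per-position max logits, then replaces the old mean-over-heads gradient weighting (grad.mean(dim=1)) with an elementwise product of ReLU'd attention maps and attention gradients, summed over heads and accumulated across all hooked layers. Each generated token's relevance row is then restricted to the <image>-token columns and folded into a square patch grid. Below is a minimal, self-contained sketch of that aggregation on toy tensors; the helper name, tensor sizes, and mask positions are illustrative assumptions, not code from this repo.

# Sketch only: mirrors the grad-CAM aggregation in the new generate_cam,
# with made-up shapes (2 layers, 8 heads, 600 tokens, 576 image patches).
import torch
import torch.nn.functional as F

def aggregate_attention_cam(attn_maps, attn_grads):
    """attn_maps / attn_grads: lists of [1, heads, seq_len, seq_len]
    tensors, one pair per hooked layer."""
    cam_sum = None
    for act, grad in zip(attn_maps, attn_grads):
        act, grad = F.relu(act), F.relu(grad)
        cam = (act * grad).sum(dim=1)  # sum over heads -> [1, seq_len, seq_len]
        cam_sum = cam if cam_sum is None else cam_sum + cam  # accumulate over layers
    return F.relu(cam_sum).to(torch.float32)

attn_maps = [torch.rand(1, 8, 600, 600) for _ in range(2)]
attn_grads = [torch.rand(1, 8, 600, 600) for _ in range(2)]
cam = aggregate_attention_cam(attn_maps, attn_grads)  # [1, 600, 600]

# For one generated-token row, keep only the image-token columns and
# fold them into a square patch grid, as the diff does per output token.
image_mask = torch.zeros(600, dtype=torch.bool)
image_mask[10:586] = True              # pretend tokens 10..585 are <image>
row = cam[0, 590][image_mask]          # one later token's attention over 576 patches
grid_size = int(row.numel() ** 0.5)    # 24
heat = row.view(grid_size, grid_size)
heat = (heat - heat.min()) / (heat.max() - heat.min())  # min-max normalize

The substantive behavioral change is the grad ⊙ attention form: sign-filtering each head's map and gradient before the per-head product and pooling, rather than weighting raw activations by head-averaged gradients.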
demo/model_utils.py  CHANGED

@@ -119,7 +119,8 @@ class LLaVA_Utils(Model_Utils):
 
     def init_LLaVA(self):
 
-        model_path = "llava-hf/llava-1.5-7b-hf"
+        # model_path = "llava-hf/llava-1.5-7b-hf"
+        model_path = "llava-hf/llava-v1.6-mistral-7b-hf"
         config = AutoConfig.from_pretrained(model_path)
 
         self.vl_gpt = LlavaForConditionalGeneration.from_pretrained(model_path,
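One caveat worth flagging on this file: llava-hf/llava-v1.6-mistral-7b-hf is a LLaVA-NeXT checkpoint, which recent transformers releases pair with the LlavaNext* model and processor classes rather than the plain LlavaForConditionalGeneration kept here. A hedged sketch of the more conventional loading path; the dtype and attention settings are illustrative assumptions, not values from this repo.

import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_path = "llava-hf/llava-v1.6-mistral-7b-hf"

# The processor bundles the tokenizer and the dynamic-resolution image preprocessor.
processor = LlavaNextProcessor.from_pretrained(model_path)

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,    # assumption: fp16 to fit a single GPU
    attn_implementation="eager",  # assumption: eager attention so hooks can read attention maps
)

Relatedly, v1.6's dynamic-resolution tiling makes the number of <image> tokens vary per input, so the image_mask built in cam.py (instead of a hard-coded 576-patch count) is the right approach; just note that grid_size = int(num_patches ** 0.5) only yields a faithful square grid when the masked token count is a perfect square.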