AustingDong committed
Commit e788822 · 1 Parent(s): 8235fd2

finished baseline

Files changed (2)
  1. demo/cam.py +6 -1
  2. demo/model_utils.py +1 -1
demo/cam.py CHANGED
@@ -535,7 +535,11 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
         elif focus == "Language Model":
             self.model.zero_grad()
             # print(outputs_raw)
-            loss = outputs_raw.logits.max(dim=-1).values.sum()
+            # loss = outputs_raw.logits.max(dim=-1).values.sum()
+            if class_idx == -1:
+                loss = outputs_raw.logits.max(dim=-1).values.sum()
+            else:
+                loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
             loss.backward()
 
 
@@ -556,6 +560,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
 
         grad = F.relu(grad)
 
+        # cam = grad
         cam = act * grad  # shape: [1, heads, seq_len, seq_len]
         cam = cam.sum(dim=1)  # shape: [1, seq_len, seq_len]
         cam = cam.to(torch.float32).detach().cpu()
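Note on the cam.py change: the backward pass can now target a single output token. When class_idx is -1 the old baseline loss (sum of per-position max logits) is kept; otherwise only the max logit of one chosen token is backpropagated, so the gradients, and hence the CAM, are specific to that token. Below is a minimal sketch of the two loss modes; the shape of logits ([1, seq_len, vocab_size]) and the meaning of start_idx (offset of the generated span) and class_idx (token position within it) are inferred from the diff, not documented behavior.

import torch

def cam_loss(logits: torch.Tensor, class_idx: int, start_idx: int) -> torch.Tensor:
    # Max logit at every sequence position: shape [1, seq_len].
    per_position_max = logits.max(dim=-1).values
    if class_idx == -1:
        # Baseline mode: aggregate over all positions (the previous behavior).
        return per_position_max.sum()
    # Targeted mode: backprop only from the selected token's max logit,
    # so gradient flows only along paths that influenced that token.
    return per_position_max[0, start_idx + class_idx]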
demo/model_utils.py CHANGED
@@ -204,7 +204,7 @@ class ChartGemma_Utils(Model_Utils):
         self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
             model_path,
             torch_dtype=torch.float16,
-            attn_implementation="eager",
+            attn_implementation="sdpa",
             output_attentions=True
         )
         self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
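Note on the model_utils.py change: the attention backend is switched from "eager" to "sdpa" (PyTorch's fused scaled_dot_product_attention kernels). One caveat worth keeping in mind: SDPA cannot return attention weights, so in the transformers versions I am aware of, a forward pass that requests output_attentions=True logs a warning and falls back to the manual (eager) attention path for those layers; attention maps stay available, but the speed benefit of SDPA is lost whenever they are requested. A sketch of the updated load call, with model_path as a placeholder:

import torch
from transformers import PaliGemmaForConditionalGeneration

model_path = "ahmed-masry/chartgemma"  # placeholder checkpoint id
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # fused kernels where supported
    output_attentions=True,      # attentions still requested; expect an eager fallback here
)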