AustingDong
commited on
Commit
·
e788822
1
Parent(s):
8235fd2
finished baseline
Browse files- demo/cam.py +6 -1
- demo/model_utils.py +1 -1
demo/cam.py
CHANGED
@@ -535,7 +535,11 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
|
|
535 |
elif focus == "Language Model":
|
536 |
self.model.zero_grad()
|
537 |
# print(outputs_raw)
|
538 |
-
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
|
|
|
|
|
|
|
|
539 |
loss.backward()
|
540 |
|
541 |
|
@@ -556,6 +560,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
|
|
556 |
|
557 |
grad = F.relu(grad)
|
558 |
|
|
|
559 |
cam = act * grad # shape: [1, heads, seq_len, seq_len]
|
560 |
cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
|
561 |
cam = cam.to(torch.float32).detach().cpu()
|
|
|
535 |
elif focus == "Language Model":
|
536 |
self.model.zero_grad()
|
537 |
# print(outputs_raw)
|
538 |
+
# loss = outputs_raw.logits.max(dim=-1).values.sum()
|
539 |
+
if class_idx == -1:
|
540 |
+
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
541 |
+
else:
|
542 |
+
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + class_idx]
|
543 |
loss.backward()
|
544 |
|
545 |
|
|
|
560 |
|
561 |
grad = F.relu(grad)
|
562 |
|
563 |
+
# cam = grad
|
564 |
cam = act * grad # shape: [1, heads, seq_len, seq_len]
|
565 |
cam = cam.sum(dim=1) # shape: [1, seq_len, seq_len]
|
566 |
cam = cam.to(torch.float32).detach().cpu()
|
demo/model_utils.py
CHANGED
@@ -204,7 +204,7 @@ class ChartGemma_Utils(Model_Utils):
|
|
204 |
self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
|
205 |
model_path,
|
206 |
torch_dtype=torch.float16,
|
207 |
-
attn_implementation="
|
208 |
output_attentions=True
|
209 |
)
|
210 |
self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
|
|
|
204 |
self.vl_gpt = PaliGemmaForConditionalGeneration.from_pretrained(
|
205 |
model_path,
|
206 |
torch_dtype=torch.float16,
|
207 |
+
attn_implementation="sdpa",
|
208 |
output_attentions=True
|
209 |
)
|
210 |
self.vl_gpt, self.dtype, self.cuda_device = set_dtype_device(self.vl_gpt)
|