AustingDong committed
Commit 217eab6 · Parent: f826f5d

modified saving and outputs

Files changed:
- app.py: +23 -6
- demo/cam.py: +1 -1
- demo/visualization.py: +17 -14
app.py
CHANGED

@@ -58,7 +58,7 @@ def multimodal_understanding(model_type,
                              activation_map_method,
                              visual_method,
                              image, question, seed, top_p, temperature, target_token_idx,
-                             visualization_layer_min, visualization_layer_max, focus, response_type, chart_type, accumulate_method):
+                             visualization_layer_min, visualization_layer_max, focus, response_type, chart_type, accumulate_method, test_selector):
     # Clear CUDA cache before generating
     gc.collect()
     if torch.cuda.is_available():

@@ -191,7 +191,7 @@ def multimodal_understanding(model_type,
 
     # Collect Results
     RESULTS_ROOT = "./results"
-    FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{visual_method}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}/{'all_tokens' if target_token_idx == -1 else f'--{input_ids_decoded[start + target_token_idx]}--'}"
+    FILES_ROOT = f"{RESULTS_ROOT}/{model_name}/{focus}/{visual_method}/{test_selector}/{chart_type}/layer{visualization_layer_min}-{visualization_layer_max}/{'all_tokens' if target_token_idx == -1 else f'--{input_ids_decoded[start + target_token_idx]}--'}"
     os.makedirs(FILES_ROOT, exist_ok=True)
 
     for i, cam_p in enumerate(cam):

@@ -350,7 +350,19 @@ def focus_change(focus):
     return res
 
 
-
+def test_change(test_selector):
+    if test_selector == "mini-VLAT":
+        return gr.Dataset(
+            samples=mini_VLAT_questions,
+        )
+    elif test_selector == "VLAT":
+        return gr.Dataset(
+            samples=VLAT_questions,
+        )
+    else:
+        return gr.Dataset(
+            samples=VLAT_old_questions,
+        )
 
 
 with gr.Blocks() as demo:

@@ -368,6 +380,7 @@ with gr.Blocks() as demo:
 
         with gr.Column():
             model_selector = gr.Dropdown(choices=["Clip", "ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="Clip", label="model")
+            test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old"], value="mini-VLAT", label="test")
            question_input = gr.Textbox(label="Input Prompt")
             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")

@@ -422,10 +435,14 @@ with gr.Blocks() as demo:
 
     examples_inpainting = gr.Examples(
         label="Multimodal Understanding examples",
-
-        examples=VLAT_questions,
+        examples=mini_VLAT_questions,
         inputs=[chart_type, question_input, image_input],
     )
+
+    test_selector.change(
+        fn=test_change,
+        inputs=test_selector,
+        outputs=examples_inpainting.dataset)
 
 

@@ -433,7 +450,7 @@ with gr.Blocks() as demo:
     understanding_button.click(
         multimodal_understanding,
         inputs=[model_selector, activation_map_method, visual_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
-                visualization_layers_min, visualization_layers_max, focus, response_type, chart_type, accumulate_method],
+                visualization_layers_min, visualization_layers_max, focus, response_type, chart_type, accumulate_method, test_selector],
         outputs=[understanding_output, activation_map_output, understanding_target_token_decoded_output]
     )
 
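The core of the app.py change is the test_change handler wired to test_selector.change with outputs=examples_inpainting.dataset, which swaps the sample list shown by a gr.Examples component. A minimal, self-contained sketch of that Gradio pattern, using hypothetical one-column question lists in place of the app's mini_VLAT_questions and VLAT_questions:

import gradio as gr

# Hypothetical stand-ins for the app's question lists.
mini_VLAT_questions = [["What is the peak value in the chart?"]]
VLAT_questions = [["What was the trend between 2010 and 2020?"]]

def test_change(test_selector):
    # Returning a gr.Dataset with new samples re-renders the Examples list.
    if test_selector == "mini-VLAT":
        return gr.Dataset(samples=mini_VLAT_questions)
    return gr.Dataset(samples=VLAT_questions)

with gr.Blocks() as demo:
    test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT"], value="mini-VLAT", label="test")
    question_input = gr.Textbox(label="Input Prompt")
    examples = gr.Examples(label="examples", examples=mini_VLAT_questions, inputs=[question_input])
    # The Dataset component backing gr.Examples is exposed as .dataset.
    test_selector.change(fn=test_change, inputs=test_selector, outputs=examples.dataset)

demo.launch()

In the real app the samples are three-column rows matching inputs=[chart_type, question_input, image_input]; the update mechanism is the same.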
demo/cam.py
CHANGED

@@ -534,7 +534,7 @@ class AttentionGuidedCAMChartGemma(AttentionGuidedCAM):
 
         elif focus == "Language Model":
             self.model.zero_grad()
-
+            print("logits shape:", outputs_raw.logits.shape)
             # loss = outputs_raw.logits.max(dim=-1).values.sum()
             if class_idx == -1:
                 loss = outputs_raw.logits.max(dim=-1).values.sum()
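For context on the branch being instrumented: when class_idx is -1, the backward target is the sum of each position's top-1 logit, so a single backward pass yields gradients covering the whole generated sequence. An illustrative snippet with dummy logits (batch, sequence, and vocab sizes here are placeholders):

import torch

# Dummy logits of shape [batch, seq_len, vocab_size].
logits = torch.randn(1, 12, 32000, requires_grad=True)
print("logits shape:", logits.shape)      # mirrors the debug print in the diff
loss = logits.max(dim=-1).values.sum()    # scalar: sum of per-position max logits
loss.backward()
print("grad shape:", logits.grad.shape)   # same shape as logits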
demo/visualization.py
CHANGED

@@ -82,8 +82,8 @@ class Visualization:
 
         grad_weights = grad.mean(dim=-1, keepdim=True)
 
-        print("act shape", act.shape)
-        print("grad_weights shape", grad_weights.shape)
+        # print("act shape", act.shape)
+        # print("grad_weights shape", grad_weights.shape)
 
         # cam = (act * grad_weights).sum(dim=-1)
         cam, _ = (act * grad_weights).max(dim=-1)

@@ -132,8 +132,8 @@ class Visualization:
 
         cams = []
         for act, grad in zip(self.activations, self.gradients):
-            print("act shape", act.shape)
-            print("grad shape", grad.shape)
+            # print("act shape", act.shape)
+            # print("grad shape", grad.shape)
 
             grad = F.relu(grad)
 

@@ -160,7 +160,7 @@ class Visualization:
 
         num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
         grid_size = int(num_patches ** 0.5)
-        print(f"Detected grid size: {grid_size}x{grid_size}")
+        # print(f"Detected grid size: {grid_size}x{grid_size}")
         cam_sum = cam_sum.view(grid_size, grid_size).detach()
 
         # Normalize

@@ -184,10 +184,10 @@ class Visualization:
         for i in range(start, cam_sum_raw.shape[1]):
             cam_sum = cam_sum_raw[:, i, :]  # shape: [1: seq_len]
             cam_sum = cam_sum[images_seq_mask].unsqueeze(0)  # shape: [1, img_seq_len]
-            print("cam_sum shape: ", cam_sum.shape)
+            # print("cam_sum shape: ", cam_sum.shape)
             num_patches = cam_sum.shape[-1]  # Last dimension of CAM output
             grid_size = int(num_patches ** 0.5)
-            print(f"Detected grid size: {grid_size}x{grid_size}")
+            # print(f"Detected grid size: {grid_size}x{grid_size}")
 
             cam_sum = cam_sum.view(grid_size, grid_size)
             if normalize:

@@ -418,6 +418,7 @@ class VisualizationChartGemma(Visualization):
 
         elif focus == "Language Model":
             self.model.zero_grad()
+            print("logits shape:", outputs_raw.logits.shape)
             if target_token_idx == -1:
                 loss = outputs_raw.logits.max(dim=-1).values.sum()
             else:

@@ -486,15 +487,17 @@ def generate_gradcam(
     Generates a Grad-CAM heatmap overlay on top of the input image.
 
     Parameters:
-
-
-
-
-
-
+        cam (torch.Tensor): A tensor of shape (C, H, W) representing the
+            intermediate activations or gradients at the target layer.
+        image (PIL.Image): The original image.
+        size (tuple): The desired size of the heatmap overlay (default (384, 384)).
+        alpha (float): The blending factor for the heatmap overlay (default 0.5).
+        colormap (int): OpenCV colormap to apply (default cv2.COLORMAP_JET).
+        aggregation (str): How to aggregate across channels; either 'mean' or 'sum'.
+        normalize (bool): Whether to normalize the heatmap (default False).
 
     Returns:
-
+        PIL.Image: The image overlaid with the Grad-CAM heatmap.
     """
     # print("Generating Grad-CAM with shape:", cam.shape)
 
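The filled-in docstring pins down generate_gradcam's contract: aggregate the (C, H, W) CAM across channels, colorize it with an OpenCV colormap, and alpha-blend it over the resized input image. A sketch of an implementation consistent with that contract (an assumption for illustration, not necessarily the repo's actual body):

import cv2
import numpy as np
import torch
from PIL import Image

def generate_gradcam(cam, image, size=(384, 384), alpha=0.5,
                     colormap=cv2.COLORMAP_JET, aggregation="mean",
                     normalize=False):
    # Aggregate channels to a single (H, W) map.
    heat = cam.mean(dim=0) if aggregation == "mean" else cam.sum(dim=0)
    heat = heat.detach().cpu().numpy().astype(np.float32)
    if normalize:
        heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    heat = np.clip(heat, 0.0, 1.0)
    # Resize, colorize (OpenCV returns BGR), and convert to RGB.
    heat = cv2.resize(heat, size, interpolation=cv2.INTER_LINEAR)
    heat = cv2.applyColorMap(np.uint8(255 * heat), colormap)
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB).astype(np.float32)
    # Alpha-blend the heatmap over the resized original image.
    base = np.asarray(image.convert("RGB").resize(size), dtype=np.float32)
    blended = (1 - alpha) * base + alpha * heat
    return Image.fromarray(np.uint8(blended))

Typical use would be overlay = generate_gradcam(cam, img, aggregation="sum", normalize=True), producing the per-layer images the app saves under FILES_ROOT.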