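# AustingDong: commit 1ca9e3b ("init"), 13.2 kB
# (file-viewer header residue: "raw", "history", "blame" links, kept as a comment so the module parses)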
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip
from demo.model_utils import Clip_Utils, Janus_Utils, add_title_to_image
import numpy as np
import matplotlib.pyplot as plt
import gc
from PIL import Image
model_seed = 42
# Seed the torch CPU, NumPy and torch CUDA generators once at import time;
# every request re-seeds them again with its own user-supplied seed.
for _seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed):
    _seed_fn(model_seed)
del _seed_fn

model_type = "Janus-1B"
janus_utils = Janus_Utils()
# init_Janus receives only the size suffix of the model name (text after the last "-").
vl_gpt, tokenizer = janus_utils.init_Janus(model_type.rsplit("-", 1)[-1])
clip_utils = Clip_Utils()
clip_utils.init_Clip()
def _select_target_layers(all_layers, layer_min, layer_max):
    """Return the layers between the 1-based slider bounds *layer_min*/*layer_max*, inclusive.

    The bounds come straight from UI sliders, so they are coerced to ``int``,
    put in ascending order and clamped to ``[1, len(all_layers)]``. The slider
    range is not always in sync with the selected model's depth.
    """
    if not all_layers:
        return []
    n_layers = len(all_layers)
    lo, hi = sorted((int(layer_min), int(layer_max)))
    lo = min(max(lo, 1), n_layers)
    hi = min(max(hi, 1), n_layers)
    return all_layers[lo - 1:hi]


# Deliberately NOT wrapped in @torch.inference_mode(): Grad-CAM needs autograd.
# @spaces.GPU(duration=120)
# Multimodal Understanding function
def multimodal_understanding(model_type,
                             saliency_map_method,
                             visual_pooling_method,
                             image, question, seed, top_p, temperature, target_token_idx,
                             visualization_layer_min, visualization_layer_max, focus):
    """Answer *question* about *image* and build saliency maps for the selected model.

    Args:
        model_type: "Clip" or "Janus-1B"; any other value produces no output.
        saliency_map_method: "GradCAM" or, for Janus only, "Attention_Map".
        visual_pooling_method: forwarded to the CAM generators ("CLS", "max" or "avg" in the UI).
        image: input image from the Gradio Image component.
        question: text prompt.
        seed: re-seeds torch (CPU and CUDA) and NumPy for this request.
        top_p, temperature: sampling parameters forwarded to Janus generation / CAM.
        target_token_idx: prompt-token index for language-model maps; -1 means all tokens.
            For Attention_Map it indexes the generation-step attentions.
        visualization_layer_min, visualization_layer_max: 1-based, inclusive layer range
            for Grad-CAM (swapped if reversed, clamped to the model's depth).
        focus: "Visual Encoder" or "Language Model".

    Returns:
        ``(answer, cam, input_text_decoded)``: the decoded answer ("" for CLIP), a list
        of saliency-map images for the gallery (empty if none were produced), and the
        decoded prompt token whose map was requested ("" unless a single
        language-model token was selected).
    """
    # Release cached CUDA memory left by the previous request before generating.
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    # Defaults so every path returns something; previously `cam` could be unbound.
    answer = ""
    cam = []
    input_text_decoded = ""

    if model_type == "Clip":
        inputs = clip_utils.prepare_inputs([question], image)
        if saliency_map_method == "GradCAM":
            all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
            target_layers = _select_target_layers(all_layers, visualization_layer_min, visualization_layer_max)
            grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
            try:
                cam_map, _outputs, _grid_size = grad_cam.generate_cam(
                    inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
                cam = [generate_gradcam(cam_map, image, size=(224, 224))]
            finally:
                # Always detach the hooks from the shared global model, even on failure.
                grad_cam.remove_hooks()

    elif model_type == "Janus-1B":
        # Grad-CAM back-propagates through the model, so gradients must be enabled.
        for param in vl_gpt.parameters():
            param.requires_grad = True
        prepare_inputs = janus_utils.prepare_inputs(question, image)
        inputs_embeds = janus_utils.generate_inputs_embeddings(prepare_inputs)
        outputs = janus_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
        sequences = outputs.sequences.cpu().tolist()
        answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
        attention_raw = outputs.attentions
        print("answer generated")

        input_ids = prepare_inputs.input_ids[0].cpu().tolist()
        input_ids_decoded = [tokenizer.decode([token_id]) for token_id in input_ids]
        # NOTE(review): hard-coded offset into the prompt tokens at which the
        # language-model CAMs start (presumably just past the image placeholder
        # tokens). Confirm against AttentionGuidedCAMJanus.generate_cam.
        start = 620

        if saliency_map_method == "GradCAM":
            if focus == "Visual Encoder":
                all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
            else:
                all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
            target_layers = _select_target_layers(all_layers, visualization_layer_min, visualization_layer_max)
            gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
            # NOTE(review): unlike the CLIP path, no hook cleanup is done here.
            # Confirm whether AttentionGuidedCAMJanus removes its own hooks.
            cam_tensors, grid_size = gradcam.generate_cam(
                prepare_inputs, tokenizer, temperature, top_p, target_token_idx,
                visual_pooling_method, focus)

            if focus == "Visual Encoder":
                cam = [generate_gradcam(cam_tensors.reshape(grid_size, grid_size), image)]
            elif target_token_idx != -1:
                input_text_decoded = input_ids_decoded[start + target_token_idx]
                if 0 <= target_token_idx < len(cam_tensors):
                    cam_grid = cam_tensors[target_token_idx].reshape(grid_size, grid_size)
                    cam = [add_title_to_image(generate_gradcam(cam_grid, image), input_text_decoded)]
            else:
                # One titled map per prompt token; use the reported grid size
                # (previously hard-coded to 24x24 here only).
                cam = [
                    add_title_to_image(
                        generate_gradcam(cam_tensor.reshape(grid_size, grid_size), image),
                        input_ids_decoded[start + i],
                    )
                    for i, cam_tensor in enumerate(cam_tensors)
                ]

        elif saliency_map_method == "Attention_Map":
            attn_m_token = attention_raw[target_token_idx]
            # Boolean mask marking image-token positions in the first sequence.
            mask = prepare_inputs.images_seq_mask[0]
            # NOTE(review): attn_m_token[1] selects a fixed entry (layer index 1?).
            # Confirm which layer is intended.
            tg = attn_m_token[1][:, :, :, :len(mask)]
            tg = tg[:, :, :, mask]
            # Max over dim 1 (presumably the attention-head axis). Confirm.
            res, _ = tg.max(dim=1)
            res = res.to(torch.float32)
            grid_size = int(res.shape[-1] ** 0.5)
            cam = [generate_gradcam(res.view(grid_size, grid_size), image)]

    return answer, cam, input_text_decoded
# Gradio interface
def update_sliders(model):
    """Rebuild the layer-range sliders and the focus dropdown for *model*.

    "Clip" gets a 1-12 layer range and can only focus on the visual encoder;
    every other model gets a 1-24 layer range and may also focus on the
    language model. Both sliders start at the top layer.

    Returns:
        (layers-min slider, layers-max slider, focus dropdown).
    """
    is_clip = model == "Clip"
    n_layers = 12 if is_clip else 24
    focus_choices = ["Visual Encoder"] if is_clip else ["Visual Encoder", "Language Model"]

    layer_min_slider = gr.Slider(minimum=1, maximum=n_layers, value=n_layers, step=1, label="visualization layers min")
    layer_max_slider = gr.Slider(minimum=1, maximum=n_layers, value=n_layers, step=1, label="visualization layers max")
    focus_dropdown = gr.Dropdown(choices=focus_choices, value="Visual Encoder", label="focus")
    return layer_min_slider, layer_max_slider, focus_dropdown
def update_visualization_layers_sliders(focus, model=None):
    """Rebuild the saliency-method dropdown and layer sliders after *focus* changes.

    Args:
        focus: "Visual Encoder" or "Language Model".
        model: currently selected model name. For "Clip" the visual-encoder
            sliders use the same 1-12 range as ``update_sliders``; any other value
            (including the default None, the old one-argument call) keeps 1-24.

    Returns:
        (saliency map type dropdown, layers-min slider, layers-max slider).
    """
    if focus == "Visual Encoder":
        n_layers = 12 if model == "Clip" else 24
        return (
            gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
            gr.Slider(minimum=1, maximum=n_layers, value=n_layers, step=1, label="visualization layers min"),
            gr.Slider(minimum=1, maximum=n_layers, value=n_layers, step=1, label="visualization layers max"),
        )
    # Language-model focus (offered only for Janus): Attention_Map becomes
    # available and both sliders default to layer 9.
    return (
        gr.Dropdown(choices=["GradCAM", "Attention_Map"], value="GradCAM", label="saliency map type"),
        gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers min"),
        gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers max"),
    )


with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image()
            saliency_map_output = gr.Gallery(label="Saliency Map", columns=1)
        with gr.Column():
            model_selector = gr.Dropdown(choices=["Clip", "Janus-1B"], value="Clip", label="model")
            focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
            saliency_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
            visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
            visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
            visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
            target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)

    model_selector.change(
        fn=update_sliders,
        inputs=model_selector,
        outputs=[
            visualization_layers_min,
            visualization_layers_max,
            focus,
        ],
    )

    # Pass the selected model too, so a focus change does not widen CLIP's
    # 12-layer slider range back to 24.
    focus.change(
        fn=update_visualization_layers_sliders,
        inputs=[focus, model_selector],
        outputs=[
            saliency_map_method,
            visualization_layers_min,
            visualization_layers_max,
        ],
    )

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")
    understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "What is the approximate global smartphone market share of Samsung?",
                "images/PieChart.png"
            ],
            [
                "What is the average internet speed in Japan?",
                "images/BarChart.png"
            ],
            [
                "What was the average price of coffee beans in October 2019?",
                "images/AreaChart.png"
            ],
            [
                "Which city's metro system has the largest number of stations?",
                "images/BubbleChart.png"
            ],
            [
                "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
                "images/Choropleth_New.png"
            ],
            [
                "What distance have customers traveled in the taxi the most?",
                "images/Histogram.png"
            ],
            [
                "What was the price of a barrel of oil in February 2020?",
                "images/LineChart.png"
            ],
            [
                "True/False: eBay is nested in the Software category.",
                "images/Treemap.png"
            ],
            [
                "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
                "images/Scatterplot.png"
            ],
            [
                "Which country has the lowest proportion of Gold medals?",
                "images/Stacked100.png"
            ],
            [
                "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
                "images/StackedArea.png"
            ],
            [
                "What is the cost of peanuts in Seoul?",
                "images/StackedBar.png"
            ],
        ],
        inputs=[question_input, image_input],
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[model_selector, saliency_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
                visualization_layers_min, visualization_layers_max, focus],
        outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output],
    )

demo.launch(share=True)
# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")