AustingDong committed
Commit 9d76fc2 · 1 Parent(s): a25a8bd

remove unused files

demo/Janus_colab_demo.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
demo/app.py DELETED
@@ -1,224 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoConfig, AutoModelForCausalLM
- from janus.models import MultiModalityCausalLM, VLChatProcessor
- from PIL import Image
-
- import numpy as np
-
-
- # Load model and processor
- model_path = "deepseek-ai/Janus-1.3B"
- config = AutoConfig.from_pretrained(model_path)
- language_config = config.language_config
- language_config._attn_implementation = 'eager'
- vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-     language_config=language_config,
-     trust_remote_code=True)
- vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
- # Multimodal Understanding function
- @torch.inference_mode()
- # Multimodal Understanding function
- def multimodal_understanding(image, question, seed, top_p, temperature):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-     # set seed
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     torch.cuda.manual_seed(seed)
-
-     conversation = [
-         {
-             "role": "User",
-             "content": f"<image_placeholder>\n{question}",
-             "images": [image],
-         },
-         {"role": "Assistant", "content": ""},
-     ]
-
-     pil_images = [Image.fromarray(image)]
-     prepare_inputs = vl_chat_processor(
-         conversations=conversation, images=pil_images, force_batchify=True
-     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-
-     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-     outputs = vl_gpt.language_model.generate(
-         inputs_embeds=inputs_embeds,
-         attention_mask=prepare_inputs.attention_mask,
-         pad_token_id=tokenizer.eos_token_id,
-         bos_token_id=tokenizer.bos_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-         max_new_tokens=512,
-         do_sample=False if temperature == 0 else True,
-         use_cache=True,
-         temperature=temperature,
-         top_p=top_p,
-     )
-
-     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     return answer
-
-
- def generate(input_ids,
-              width,
-              height,
-              temperature: float = 1,
-              parallel_size: int = 5,
-              cfg_weight: float = 5,
-              image_token_num_per_image: int = 576,
-              patch_size: int = 16):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-     for i in range(parallel_size * 2):
-         tokens[i, :] = input_ids
-         if i % 2 != 0:
-             tokens[i, 1:-1] = vl_chat_processor.pad_id
-     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
-
-     pkv = None
-     for i in range(image_token_num_per_image):
-         outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-             use_cache=True,
-             past_key_values=pkv)
-         pkv = outputs.past_key_values
-         hidden_states = outputs.last_hidden_state
-         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-         logit_cond = logits[0::2, :]
-         logit_uncond = logits[1::2, :]
-         logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-         probs = torch.softmax(logits / temperature, dim=-1)
-         next_token = torch.multinomial(probs, num_samples=1)
-         generated_tokens[:, i] = next_token.squeeze(dim=-1)
-         next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-         inputs_embeds = img_embeds.unsqueeze(dim=1)
-     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-         shape=[parallel_size, 8, width // patch_size, height // patch_size])
-
-     return generated_tokens.to(dtype=torch.int), patches
-
- def unpack(dec, width, height, parallel_size=5):
-     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
-     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-     visual_img[:, :, :] = dec
-
-     return visual_img
-
-
-
- @torch.inference_mode()
- def generate_image(prompt,
-                    seed=None,
-                    guidance=5):
-     # Clear CUDA cache and avoid tracking gradients
-     torch.cuda.empty_cache()
-     # Set the seed for reproducible results
-     if seed is not None:
-         torch.manual_seed(seed)
-         torch.cuda.manual_seed(seed)
-         np.random.seed(seed)
-     width = 384
-     height = 384
-     parallel_size = 5
-
-     with torch.no_grad():
-         messages = [{'role': 'User', 'content': prompt},
-             {'role': 'Assistant', 'content': ''}]
-         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-             sft_format=vl_chat_processor.sft_format,
-             system_prompt='')
-         text = text + vl_chat_processor.image_start_tag
-         input_ids = torch.LongTensor(tokenizer.encode(text))
-         output, patches = generate(input_ids,
-             width // 16 * 16,
-             height // 16 * 16,
-             cfg_weight=guidance,
-             parallel_size=parallel_size)
-         images = unpack(patches,
-             width // 16 * 16,
-             height // 16 * 16)
-
-         return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
-
-
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown(value="# Multimodal Understanding")
-     # with gr.Row():
-     with gr.Row():
-         image_input = gr.Image()
-         with gr.Column():
-             question_input = gr.Textbox(label="Question")
-             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-     understanding_button = gr.Button("Chat")
-     understanding_output = gr.Textbox(label="Response")
-
-     examples_inpainting = gr.Examples(
-         label="Multimodal Understanding examples",
-         examples=[
-             [
-                 "explain this meme",
-                 "images/doge.png",
-             ],
-             [
-                 "Convert the formula into latex code.",
-                 "images/equation.png",
-             ],
-         ],
-         inputs=[question_input, image_input],
-     )
-
-
-     gr.Markdown(value="# Text-to-Image Generation")
-
-
-
-     with gr.Row():
-         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-
-     prompt_input = gr.Textbox(label="Prompt")
-     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
-
-     generation_button = gr.Button("Generate Images")
-
-     image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-     examples_t2i = gr.Examples(
-         label="Text to image generation examples. (Tips for designing prompts: Adding description like 'digital art' at the end of the prompt or writing the prompt in more detail can help produce better images!)",
-         examples=[
-             "Master shifu racoon wearing drip attire as a street gangster.",
-             "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-             "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-         ],
-         inputs=prompt_input,
-     )
-
-     understanding_button.click(
-         multimodal_understanding,
-         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-         outputs=understanding_output
-     )
-
-     generation_button.click(
-         fn=generate_image,
-         inputs=[prompt_input, seed_input, cfg_weight_input],
-         outputs=image_output
-     )
-
- demo.launch(share=True)
 
demo/app_janusflow.py DELETED
@@ -1,247 +0,0 @@
- import gradio as gr
- import torch
- from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
- from PIL import Image
- from diffusers.models import AutoencoderKL
- import numpy as np
-
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # Load model and processor
- model_path = "deepseek-ai/JanusFlow-1.3B"
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
-
- vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
- vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
-
- # remember to use bfloat16 dtype, this vae doesn't work with fp16
- vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
- vae = vae.to(torch.bfloat16).to(cuda_device).eval()
-
- # Multimodal Understanding function
- @torch.inference_mode()
- # Multimodal Understanding function
- def multimodal_understanding(image, question, seed, top_p, temperature):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-     # set seed
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     torch.cuda.manual_seed(seed)
-
-     conversation = [
-         {
-             "role": "User",
-             "content": f"<image_placeholder>\n{question}",
-             "images": [image],
-         },
-         {"role": "Assistant", "content": ""},
-     ]
-
-     pil_images = [Image.fromarray(image)]
-     prepare_inputs = vl_chat_processor(
-         conversations=conversation, images=pil_images, force_batchify=True
-     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-
-     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-     outputs = vl_gpt.language_model.generate(
-         inputs_embeds=inputs_embeds,
-         attention_mask=prepare_inputs.attention_mask,
-         pad_token_id=tokenizer.eos_token_id,
-         bos_token_id=tokenizer.bos_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-         max_new_tokens=512,
-         do_sample=False if temperature == 0 else True,
-         use_cache=True,
-         temperature=temperature,
-         top_p=top_p,
-     )
-
-     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-
-     return answer
-
-
- @torch.inference_mode()
- def generate(
-     input_ids,
-     cfg_weight: float = 2.0,
-     num_inference_steps: int = 30
- ):
-     # we generate 5 images at a time, *2 for CFG
-     tokens = torch.stack([input_ids] * 10).cuda()
-     tokens[5:, 1:] = vl_chat_processor.pad_id
-     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-     print(inputs_embeds.shape)
-
-     # we remove the last <bog> token and replace it with t_emb later
-     inputs_embeds = inputs_embeds[:, :-1, :]
-
-     # generate with rectified flow ode
-     # step 1: encode with vision_gen_enc
-     z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
-
-     dt = 1.0 / num_inference_steps
-     dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
-
-     # step 2: run ode
-     attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
-     attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
-     attention_mask = attention_mask.int()
-     for step in range(num_inference_steps):
-         # prepare inputs for the llm
-         z_input = torch.cat([z, z], dim=0) # for cfg
-         t = step / num_inference_steps * 1000.
-         t = torch.tensor([t] * z_input.shape[0]).to(dt)
-         z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
-         z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
-         z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
-         z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
-         llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
-
-         # input to the llm
-         # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
-         if step == 0:
-             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
-                 use_cache=True,
-                 attention_mask=attention_mask,
-                 past_key_values=None)
-             past_key_values = []
-             for kv_cache in past_key_values:
-                 k, v = kv_cache[0], kv_cache[1]
-                 past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
-             past_key_values = tuple(past_key_values)
-         else:
-             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
-                 use_cache=True,
-                 attention_mask=attention_mask,
-                 past_key_values=past_key_values)
-         hidden_states = outputs.last_hidden_state
-
-         # transform hidden_states back to v
-         hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
-         hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
-         v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
-         v_cond, v_uncond = torch.chunk(v, 2)
-         v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
-         z = z + dt * v
-
-     # step 3: decode with vision_gen_dec and sdxl vae
-     decoded_image = vae.decode(z / vae.config.scaling_factor).sample
-
-     images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
-     images = ((images+1) / 2. * 255).astype(np.uint8)
-
-     return images
-
- def unpack(dec, width, height, parallel_size=5):
-     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
-     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-     visual_img[:, :, :] = dec
-
-     return visual_img
-
-
- @torch.inference_mode()
- def generate_image(prompt,
-                    seed=None,
-                    guidance=5,
-                    num_inference_steps=30):
-     # Clear CUDA cache and avoid tracking gradients
-     torch.cuda.empty_cache()
-     # Set the seed for reproducible results
-     if seed is not None:
-         torch.manual_seed(seed)
-         torch.cuda.manual_seed(seed)
-         np.random.seed(seed)
-
-     with torch.no_grad():
-         messages = [{'role': 'User', 'content': prompt},
-             {'role': 'Assistant', 'content': ''}]
-         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-             sft_format=vl_chat_processor.sft_format,
-             system_prompt='')
-         text = text + vl_chat_processor.image_start_tag
-         input_ids = torch.LongTensor(tokenizer.encode(text))
-         images = generate(input_ids,
-             cfg_weight=guidance,
-             num_inference_steps=num_inference_steps)
-         return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
-
-
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown(value="# Multimodal Understanding")
-     # with gr.Row():
-     with gr.Row():
-         image_input = gr.Image()
-         with gr.Column():
-             question_input = gr.Textbox(label="Question")
-             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-     understanding_button = gr.Button("Chat")
-     understanding_output = gr.Textbox(label="Response")
-
-     examples_inpainting = gr.Examples(
-         label="Multimodal Understanding examples",
-         examples=[
-             [
-                 "explain this meme",
-                 "./images/doge.png",
-             ],
-             [
-                 "Convert the formula into latex code.",
-                 "./images/equation.png",
-             ],
-         ],
-         inputs=[question_input, image_input],
-     )
-
-
-     gr.Markdown(value="# Text-to-Image Generation")
-
-
-
-     with gr.Row():
-         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
-         step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
-
-     prompt_input = gr.Textbox(label="Prompt")
-     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
-
-     generation_button = gr.Button("Generate Images")
-
-     image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-     examples_t2i = gr.Examples(
-         label="Text to image generation examples.",
-         examples=[
-             "Master shifu racoon wearing drip attire as a street gangster.",
-             "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-             "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-         ],
-         inputs=prompt_input,
-     )
-
-     understanding_button.click(
-         multimodal_understanding,
-         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-         outputs=understanding_output
-     )
-
-     generation_button.click(
-         fn=generate_image,
-         inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
-         outputs=image_output
-     )
-
- demo.launch(share=True)
 
demo/app_januspro.py DELETED
@@ -1,294 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoConfig, AutoModelForCausalLM
- from janus.models import MultiModalityCausalLM, VLChatProcessor
- from janus.utils.io import load_pil_images
- from demo.cam import generate_gradcam, GradCAM, AttentionGuidedCAM
- from PIL import Image
- from einops import rearrange
-
- import numpy as np
- import os
- import time
- # import spaces # Import spaces for ZeroGPU compatibility
-
-
- # Load model and processor
- # model_path = "deepseek-ai/Janus-Pro-7B"
- model_path = "deepseek-ai/Janus-Pro-1B"
- config = AutoConfig.from_pretrained(model_path)
- language_config = config.language_config
- language_config._attn_implementation = 'eager'
- vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-     language_config=language_config,
-     trust_remote_code=True)
-
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
- # dtype = torch.bfloat32 if torch.cuda.is_available() else torch.float32
-
- if torch.cuda.is_available():
-     vl_gpt = vl_gpt.to(dtype).cuda()
- else:
-     # vl_gpt = vl_gpt.to(torch.float16)
-     torch.set_default_device("mps")
-     vl_gpt = vl_gpt.to(dtype)
-
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
- cuda_device = 'cuda' if torch.cuda.is_available() else 'mps'
-
-
-
- # @torch.inference_mode() # cancel inference, for gradcam
- # @spaces.GPU(duration=120)
- # Multimodal Understanding function
- def multimodal_understanding(image, question, seed, top_p, temperature, target_token_idx):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-
-     for param in vl_gpt.parameters():
-         param.requires_grad = True
-
-     # set seed
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     torch.cuda.manual_seed(seed)
-
-
-     # Get the last transformer block of the Vision Transformer (ViT)
-
-
-     conversation = [
-         {
-             "role": "<|User|>",
-             "content": f"<image_placeholder>\n{question}",
-             "images": [image],
-         },
-         {"role": "<|Assistant|>", "content": ""},
-     ]
-
-     pil_images = [Image.fromarray(image)]
-     prepare_inputs = vl_chat_processor(
-         conversations=conversation, images=pil_images, force_batchify=True
-     ).to(cuda_device, dtype=dtype)
-
-
-
-
-     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-     # print("prepared inputs", prepare_inputs)
-
-
-     outputs = vl_gpt.language_model.generate(
-         inputs_embeds=inputs_embeds,
-         attention_mask=prepare_inputs.attention_mask,
-         pad_token_id=tokenizer.eos_token_id,
-         bos_token_id=tokenizer.bos_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-         max_new_tokens=512,
-         do_sample=False if temperature == 0 else True,
-         use_cache=True,
-         temperature=temperature,
-         top_p=top_p,
-     )
-
-
-
-     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     print("answer generated")
-
-
-     target_layer = vl_gpt.vision_model.vision_tower.blocks
-
-     gradcam = AttentionGuidedCAM(vl_gpt, target_layer)
-     cam_tensor, output, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx)
-     cam_grid = cam_tensor.reshape(grid_size, grid_size)
-     cam = generate_gradcam(cam_grid, image)
-
-     output_arr = output.logits.detach().to(float).to("cpu").numpy()
-     predicted_ids = np.argmax(output_arr, axis=-1)  # [1, num_tokens]
-     predicted_ids = predicted_ids.squeeze(0)  # [num_tokens]
-     target_token_decoded = tokenizer.decode(predicted_ids[target_token_idx].tolist())
-
-     return answer, [cam], target_token_decoded
-
-
- def generate(input_ids,
-              width,
-              height,
-              temperature: float = 1,
-              parallel_size: int = 5,
-              cfg_weight: float = 5,
-              image_token_num_per_image: int = 576,
-              patch_size: int = 16):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-     for i in range(parallel_size * 2):
-         tokens[i, :] = input_ids
-         if i % 2 != 0:
-             tokens[i, 1:-1] = vl_chat_processor.pad_id
-     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
-
-     pkv = None
-     for i in range(image_token_num_per_image):
-         with torch.no_grad():
-             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                 use_cache=True,
-                 past_key_values=pkv)
-             pkv = outputs.past_key_values
-             hidden_states = outputs.last_hidden_state
-             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-             logit_cond = logits[0::2, :]
-             logit_uncond = logits[1::2, :]
-             logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-             probs = torch.softmax(logits / temperature, dim=-1)
-             next_token = torch.multinomial(probs, num_samples=1)
-             generated_tokens[:, i] = next_token.squeeze(dim=-1)
-             next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-
-             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-             inputs_embeds = img_embeds.unsqueeze(dim=1)
-
-
-
-     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-         shape=[parallel_size, 8, width // patch_size, height // patch_size])
-
-     return generated_tokens.to(dtype=torch.int), patches
-
- def unpack(dec, width, height, parallel_size=5):
-     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
-     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-     visual_img[:, :, :] = dec
-
-     return visual_img
-
-
-
- @torch.inference_mode()
- # @spaces.GPU(duration=120) # Specify a duration to avoid timeout
- def generate_image(prompt,
-                    seed=None,
-                    guidance=5,
-                    t2i_temperature=1.0):
-     # Clear CUDA cache and avoid tracking gradients
-     torch.cuda.empty_cache()
-     # Set the seed for reproducible results
-     if seed is not None:
-         torch.manual_seed(seed)
-         torch.cuda.manual_seed(seed)
-         np.random.seed(seed)
-     width = 384
-     height = 384
-     parallel_size = 5
-
-     with torch.no_grad():
-         messages = [{'role': '<|User|>', 'content': prompt},
-             {'role': '<|Assistant|>', 'content': ''}]
-         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-             sft_format=vl_chat_processor.sft_format,
-             system_prompt='')
-         text = text + vl_chat_processor.image_start_tag
-
-         input_ids = torch.LongTensor(tokenizer.encode(text))
-         output, patches = generate(input_ids,
-             width // 16 * 16,
-             height // 16 * 16,
-             cfg_weight=guidance,
-             parallel_size=parallel_size,
-             temperature=t2i_temperature)
-         images = unpack(patches,
-             width // 16 * 16,
-             height // 16 * 16,
-             parallel_size=parallel_size)
-
-         return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
-
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown(value="# Multimodal Understanding")
-     with gr.Row():
-         with gr.Column():
-             image_input = gr.Image()
-             saliency_map_output = gr.Gallery(label="Saliency Map", columns=1, rows=1, height=300)
-
-         with gr.Column():
-             question_input = gr.Textbox(label="Question")
-             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-             target_token_idx = gr.Number(label="target_token_idx", precision=0, value=300)
-
-     understanding_button = gr.Button("Chat")
-     understanding_output = gr.Textbox(label="Response")
-     understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")
-
-
-     examples_inpainting = gr.Examples(
-         label="Multimodal Understanding examples",
-         examples=[
-             [
-                 "explain this meme",
-                 "images/doge.png",
-             ],
-             [
-                 "Convert the formula into latex code.",
-                 "images/equation.png",
-             ],
-         ],
-         inputs=[question_input, image_input],
-     )
-
-
-
-
-     gr.Markdown(value="# Text-to-Image Generation")
-
-
-
-     with gr.Row():
-         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-         t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
-
-     prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
-     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
-
-     generation_button = gr.Button("Generate Images")
-
-     image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-     examples_t2i = gr.Examples(
-         label="Text to image generation examples.",
-         examples=[
-             "Master shifu racoon wearing drip attire as a street gangster.",
-             "The face of a beautiful girl",
-             "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-             "A glass of red wine on a reflective surface.",
-             "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-             "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-         ],
-         inputs=prompt_input,
-     )
-
-     understanding_button.click(
-         multimodal_understanding,
-         inputs=[image_input, question_input, und_seed_input, top_p, temperature, target_token_idx],
-         outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output]
-     )
-
-     generation_button.click(
-         fn=generate_image,
-         inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-         outputs=image_output
-     )
-
- demo.launch(share=True)
- # demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
 
demo/app_vqa.py DELETED
@@ -1,333 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoConfig, AutoModelForCausalLM
- from janus.models import MultiModalityCausalLM, VLChatProcessor
- from janus.utils.io import load_pil_images
- from demo.cam import generate_gradcam, AttentionGuidedCAMJanus, AttentionGuidedCAMClip
- from demo.model_utils import Clip_Utils, Janus_Utils, add_title_to_image
-
- import numpy as np
- import matplotlib.pyplot as plt
- import gc
- from PIL import Image
-
- model_seed = 42
- torch.manual_seed(model_seed)
- np.random.seed(model_seed)
- torch.cuda.manual_seed(model_seed)
-
- model_type = "Janus-1B"
- janus_utils = Janus_Utils()
- vl_gpt, tokenizer = janus_utils.init_Janus(model_type.split('-')[-1])
-
- clip_utils = Clip_Utils()
- clip_utils.init_Clip()
-
- # @torch.inference_mode() # cancel inference, for gradcam
- # @spaces.GPU(duration=120)
- # Multimodal Understanding function
- def multimodal_understanding(model_type,
-                              saliency_map_method,
-                              visual_pooling_method,
-                              image, question, seed, top_p, temperature, target_token_idx,
-                              visualization_layer_min, visualization_layer_max, focus):
-     # Clear CUDA cache before generating
-     torch.cuda.empty_cache()
-
-     # set seed
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     torch.cuda.manual_seed(seed)
-
-     input_text_decoded = ""
-     if model_type == "Clip":
-
-         inputs = clip_utils.prepare_inputs([question], image)
-
-
-         if saliency_map_method == "GradCAM":
-             # Generate Grad-CAM
-             all_layers = [layer.layer_norm1 for layer in clip_utils.model.vision_model.encoder.layers]
-             if visualization_layers_min.value != visualization_layers_max.value:
-                 target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
-             else:
-                 target_layers = [all_layers[visualization_layer_min-1]]
-             grad_cam = AttentionGuidedCAMClip(clip_utils.model, target_layers)
-             cam, outputs, grid_size = grad_cam.generate_cam(inputs, class_idx=0, visual_pooling_method=visual_pooling_method)
-             cam = [generate_gradcam(cam, image, size=(224, 224))]
-             grad_cam.remove_hooks()
-             target_token_decoded = ""
-             answer = ""
-
-
-     elif model_type == "Janus-1B":
-
-         for param in vl_gpt.parameters():
-             param.requires_grad = True
-
-
-         prepare_inputs = janus_utils.prepare_inputs(question, image)
-         inputs_embeds = janus_utils.generate_inputs_embeddings(prepare_inputs)
-         outputs = janus_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
-
-         sequences = outputs.sequences.cpu().tolist()
-         answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
-         attention_raw = outputs.attentions
-         print("answer generated")
-
-         input_ids = prepare_inputs.input_ids[0].cpu().tolist()
-         input_ids_decoded = [tokenizer.decode([input_ids[i]]) for i in range(len(input_ids))]
-         start=620
-
-         if saliency_map_method == "GradCAM":
-             # target_layers = vl_gpt.vision_model.vision_tower.blocks
-             if focus == "Visual Encoder":
-                 all_layers = [block.norm1 for block in vl_gpt.vision_model.vision_tower.blocks]
-             else:
-                 all_layers = [layer.self_attn for layer in vl_gpt.language_model.model.layers]
-
-             if visualization_layers_min.value != visualization_layers_max.value:
-                 target_layers = all_layers[visualization_layer_min-1 : visualization_layer_max-1]
-             else:
-                 target_layers = [all_layers[visualization_layer_min-1]]
-
-             gradcam = AttentionGuidedCAMJanus(vl_gpt, target_layers)
-             cam_tensors, grid_size = gradcam.generate_cam(prepare_inputs, tokenizer, temperature, top_p, target_token_idx, visual_pooling_method, focus)
-             if focus == "Visual Encoder":
-                 cam_grid = cam_tensors.reshape(grid_size, grid_size)
-                 cam = [generate_gradcam(cam_grid, image)]
-             else:
-                 if target_token_idx != -1:
-                     input_text_decoded = input_ids_decoded[start + target_token_idx]
-                     for i, cam_tensor in enumerate(cam_tensors):
-                         if i == target_token_idx:
-                             cam_grid = cam_tensor.reshape(grid_size, grid_size)
-                             cam_i = generate_gradcam(cam_grid, image)
-                             cam = [add_title_to_image(cam_i, input_text_decoded)]
-                             break
-                 else:
-                     cam = []
-                     for i, cam_tensor in enumerate(cam_tensors):
-                         cam_grid = cam_tensor.reshape(24, 24)
-                         cam_i = generate_gradcam(cam_grid, image)
-                         cam_i = add_title_to_image(cam_i, input_ids_decoded[start + i])
-
-                         cam.append(cam_i)
-
-             # widths, heights = zip(*(img.size for img in heatmaps))
-             # total_height = sum(heights)
-             # max_width = max(widths)
-
-             # combined_img = Image.new("RGB", (max_width, total_height))
-
-             # y_offset = 0
-             # for img in heatmaps:
-             #     combined_img.paste(img, (0, y_offset)) # Stack vertically
-             #     y_offset += img.height
-             # cam = combined_img
-
-
-
-
-         elif saliency_map_method == "Attention_Map":
-             attn_m_token = attention_raw[target_token_idx]
-             img_token_positions = prepare_inputs.images_seq_mask
-             mask = img_token_positions[0]
-
-             tg = attn_m_token[1][:, :, :, :len(mask)]
-             tg = tg[:, :, :, mask]
-             head = 0
-
-             # res = tg[0, head, 0].to(torch.float32)
-             res, _ = tg.max(dim=1)
-             # res = tg.sum(dim=1)
-             res = res.to(torch.float32)
-             grid_size = (int)(res.shape[-1] ** 0.5)
-             res = res.view(grid_size, grid_size)
-             cam = [generate_gradcam(res, image)]
-
-
-     # output_arr = output.logits.detach().to(float).to("cpu").numpy()
-     # predicted_ids = np.argmax(output_arr, axis=-1) # [1, num_tokens]
-     # predicted_ids = predicted_ids.squeeze(0) # [num_tokens]
-     # target_token_decoded = tokenizer.decode(predicted_ids[target_token_idx].tolist())
-
-
-     return answer, cam, input_text_decoded
-
-
-
-
- # Gradio interface
-
- def update_sliders(model):
-     if model == "Clip":
-         res = (
-             gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min"),
-             gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max"),
-             gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
-         )
-         return res
-     else:
-         res = (
-             gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
-             gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max"),
-             gr.Dropdown(choices=["Visual Encoder", "Language Model"], value="Visual Encoder", label="focus")
-         )
-         return res
-
- def update_visualization_layers_sliders(focus):
-     if focus == "Visual Encoder":
-         res = (
-             gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type"),
-             gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers min"),
-             gr.Slider(minimum=1, maximum=24, value=24, step=1, label="visualization layers max")
-         )
-         return res
-     else:
-         res = (
-             gr.Dropdown(choices=["GradCAM", "Attention_Map"], value="GradCAM", label="saliency map type"),
-             gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers min"),
-             gr.Slider(minimum=1, maximum=24, value=9, step=1, label="visualization layers max")
-         )
-         return res
-
- with gr.Blocks() as demo:
-     gr.Markdown(value="# Multimodal Understanding")
-     with gr.Row():
-         with gr.Column():
-             image_input = gr.Image()
-             saliency_map_output = gr.Gallery(label="Saliency Map", columns=1)
-
-         with gr.Column():
-             model_selector = gr.Dropdown(choices=["Clip", "Janus-1B"], value="Clip", label="model")
-             focus = gr.Dropdown(choices=["Visual Encoder"], value="Visual Encoder", label="focus")
-             saliency_map_method = gr.Dropdown(choices=["GradCAM"], value="GradCAM", label="saliency map type")
-             visual_pooling_method = gr.Dropdown(choices=["CLS", "max", "avg"], value="CLS", label="visual pooling method")
-
-
-             visualization_layers_min = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers min")
-             visualization_layers_max = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="visualization layers max")
-
-             question_input = gr.Textbox(label="Question")
-             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-             target_token_idx = gr.Number(label="target_token_idx (-1 means all)", precision=0, value=-1)
-
-
-
-     model_selector.change(
-         fn=update_sliders,
-         inputs=model_selector,
-         outputs=[
-             visualization_layers_min,
-             visualization_layers_max,
-             focus
-         ]
-     )
-
-     focus.change(
-         fn = update_visualization_layers_sliders,
-         inputs = focus,
-         outputs=[
-             saliency_map_method,
-             visualization_layers_min,
-             visualization_layers_max,
-         ]
-     )
-
-
-
-     understanding_button = gr.Button("Chat")
-     understanding_output = gr.Textbox(label="Response")
-     understanding_target_token_decoded_output = gr.Textbox(label="Target Token Decoded")
-
-
-     examples_inpainting = gr.Examples(
-         label="Multimodal Understanding examples",
-         examples=[
-
-             [
-                 "What is the approximate global smartphone market share of Samsung?",
-                 "images/PieChart.png"
-             ],
-             [
-                 "What is the average internet speed in Japan?",
-                 "images/BarChart.png"
-             ],
-             [
-                 "What was the average price of coffee beans in October 2019?",
-                 "images/AreaChart.png"
-             ],
-             [
-                 "Which city's metro system has the largest number of stations?",
-                 "images/BubbleChart.png"
-             ],
-
-             [
-                 "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
-                 "images/Choropleth_New.png"
-             ],
-
-             [
-                 "What distance have customers traveled in the taxi the most?",
-                 "images/Histogram.png"
-             ],
-
-             [
-                 "What was the price of a barrel of oil in February 2020?",
-                 "images/LineChart.png"
-             ],
-
-             [
-                 "True/False: eBay is nested in the Software category.",
-                 "images/Treemap.png"
-             ],
-
-             [
-                 "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
-                 "images/Scatterplot.png"
-             ],
-
-             [
-                 "Which country has the lowest proportion of Gold medals?",
-                 "images/Stacked100.png"
-             ],
-
-             [
-                 "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
-                 "images/StackedArea.png"
-             ],
-
-             [
-                 "What is the cost of peanuts in Seoul?",
-                 "images/StackedBar.png"
-             ],
-
-
-             # [
-             #     "explain this meme",
-             #     "images/doge.png",
-             # ],
-             # [
-             #     "Convert the formula into latex code.",
-             #     "images/equation.png",
-             # ],
-
-         ],
-         inputs=[question_input, image_input],
-     )
-
-
-
-
-     understanding_button.click(
-         multimodal_understanding,
-         inputs=[model_selector, saliency_map_method, visual_pooling_method, image_input, question_input, und_seed_input, top_p, temperature, target_token_idx,
-                 visualization_layers_min, visualization_layers_max, focus],
-         outputs=[understanding_output, saliency_map_output, understanding_target_token_decoded_output]
-     )
-
- demo.launch(share=True)
- # demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path")
 
demo/demo.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
demo/demo_attn.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
demo/fastapi_app.py DELETED
@@ -1,178 +0,0 @@
- from fastapi import FastAPI, File, Form, UploadFile, HTTPException
- from fastapi.responses import JSONResponse, StreamingResponse
- import torch
- from transformers import AutoConfig, AutoModelForCausalLM
- from janus.models import MultiModalityCausalLM, VLChatProcessor
- from PIL import Image
- import numpy as np
- import io
-
- app = FastAPI()
-
- # Load model and processor
- model_path = "deepseek-ai/Janus-1.3B"
- config = AutoConfig.from_pretrained(model_path)
- language_config = config.language_config
- language_config._attn_implementation = 'eager'
- vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-     language_config=language_config,
-     trust_remote_code=True)
- vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
- @torch.inference_mode()
- def multimodal_understanding(image_data, question, seed, top_p, temperature):
-     torch.cuda.empty_cache()
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     torch.cuda.manual_seed(seed)
-
-     conversation = [
-         {
-             "role": "User",
-             "content": f"<image_placeholder>\n{question}",
-             "images": [image_data],
-         },
-         {"role": "Assistant", "content": ""},
-     ]
-
-     pil_images = [Image.open(io.BytesIO(image_data))]
-     prepare_inputs = vl_chat_processor(
-         conversations=conversation, images=pil_images, force_batchify=True
-     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
-     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-     outputs = vl_gpt.language_model.generate(
-         inputs_embeds=inputs_embeds,
-         attention_mask=prepare_inputs.attention_mask,
-         pad_token_id=tokenizer.eos_token_id,
-         bos_token_id=tokenizer.bos_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-         max_new_tokens=512,
-         do_sample=False if temperature == 0 else True,
-         use_cache=True,
-         temperature=temperature,
-         top_p=top_p,
-     )
-
-     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     return answer
-
-
- @app.post("/understand_image_and_question/")
- async def understand_image_and_question(
-     file: UploadFile = File(...),
-     question: str = Form(...),
-     seed: int = Form(42),
-     top_p: float = Form(0.95),
-     temperature: float = Form(0.1)
- ):
-     image_data = await file.read()
-     response = multimodal_understanding(image_data, question, seed, top_p, temperature)
-     return JSONResponse({"response": response})
-
-
- def generate(input_ids,
-              width,
-              height,
-              temperature: float = 1,
-              parallel_size: int = 5,
-              cfg_weight: float = 5,
-              image_token_num_per_image: int = 576,
-              patch_size: int = 16):
-     torch.cuda.empty_cache()
-     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-     for i in range(parallel_size * 2):
-         tokens[i, :] = input_ids
-         if i % 2 != 0:
-             tokens[i, 1:-1] = vl_chat_processor.pad_id
-     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
-
-     pkv = None
-     for i in range(image_token_num_per_image):
-         outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-         pkv = outputs.past_key_values
-         hidden_states = outputs.last_hidden_state
-         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-         logit_cond = logits[0::2, :]
-         logit_uncond = logits[1::2, :]
-         logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-         probs = torch.softmax(logits / temperature, dim=-1)
-         next_token = torch.multinomial(probs, num_samples=1)
-         generated_tokens[:, i] = next_token.squeeze(dim=-1)
-         next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-         inputs_embeds = img_embeds.unsqueeze(dim=1)
-     patches = vl_gpt.gen_vision_model.decode_code(
-         generated_tokens.to(dtype=torch.int),
-         shape=[parallel_size, 8, width // patch_size, height // patch_size]
-     )
-
-     return generated_tokens.to(dtype=torch.int), patches
-
-
- def unpack(dec, width, height, parallel_size=5):
-     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
-     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
-     visual_img[:, :, :] = dec
-
-     return visual_img
-
-
- @torch.inference_mode()
- def generate_image(prompt, seed, guidance):
-     torch.cuda.empty_cache()
-     seed = seed if seed is not None else 12345
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-     np.random.seed(seed)
-     width = 384
-     height = 384
-     parallel_size = 5
-
-     with torch.no_grad():
-         messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
-         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-             conversations=messages,
-             sft_format=vl_chat_processor.sft_format,
-             system_prompt=''
-         )
-         text = text + vl_chat_processor.image_start_tag
-         input_ids = torch.LongTensor(tokenizer.encode(text))
-         _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
-         images = unpack(patches, width // 16 * 16, height // 16 * 16)
-
-         return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
-
-
- @app.post("/generate_images/")
- async def generate_images(
-     prompt: str = Form(...),
-     seed: int = Form(None),
-     guidance: float = Form(5.0),
- ):
-     try:
-         images = generate_image(prompt, seed, guidance)
-         def image_stream():
-             for img in images:
-                 buf = io.BytesIO()
-                 img.save(buf, format='PNG')
-                 buf.seek(0)
-                 yield buf.read()
-
-         return StreamingResponse(image_stream(), media_type="multipart/related")
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
-
-
-
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
 
demo/fastapi_client.py DELETED
@@ -1,78 +0,0 @@
- import requests
- from PIL import Image
- import io
- # Endpoint URLs
- understand_image_url = "http://localhost:8000/understand_image_and_question/"
- generate_images_url = "http://localhost:8000/generate_images/"
-
- # Use your image file path here
- image_path = "images/equation.png"
-
- # Function to call the image understanding endpoint
- def understand_image_and_question(image_path, question, seed=42, top_p=0.95, temperature=0.1):
-     files = {'file': open(image_path, 'rb')}
-     data = {
-         'question': question,
-         'seed': seed,
-         'top_p': top_p,
-         'temperature': temperature
-     }
-     response = requests.post(understand_image_url, files=files, data=data)
-     response_data = response.json()
-     print("Image Understanding Response:", response_data['response'])
-
-
- # Function to call the text-to-image generation endpoint
- def generate_images(prompt, seed=None, guidance=5.0):
-     data = {
-         'prompt': prompt,
-         'seed': seed,
-         'guidance': guidance
-     }
-     response = requests.post(generate_images_url, data=data, stream=True)
-
-     if response.ok:
-         img_idx = 1
-
-         # We will create a new BytesIO for each image
-         buffers = {}
-
-         try:
-             for chunk in response.iter_content(chunk_size=1024):
-                 if chunk:
-                     # Use a boundary detection to determine new image start
-                     if img_idx not in buffers:
-                         buffers[img_idx] = io.BytesIO()
-
-                     buffers[img_idx].write(chunk)
-
-                     # Attempt to open the image
-                     try:
-                         buffer = buffers[img_idx]
-                         buffer.seek(0)
-                         image = Image.open(buffer)
-                         img_path = f"generated_image_{img_idx}.png"
-                         image.save(img_path)
-                         print(f"Saved: {img_path}")
-
-                         # Prepare the next image buffer
-                         buffer.close()
-                         img_idx += 1
-
-                     except Exception as e:
-                         # Continue loading data into the current buffer
-                         continue
-
-         except Exception as e:
-             print("Error processing image:", e)
-     else:
-         print("Failed to generate images.")
-
-
- # Example usage
- if __name__ == "__main__":
-     # Call the image understanding API
-     understand_image_and_question(image_path, "What is this image about?")
-
-     # Call the image generation API
-     generate_images("A beautiful sunset over a mountain range, digital art.")