prithivMLmods committed (verified)
Commit d48854d · 1 Parent(s): 1a48c6f

Update app.py

Files changed (1)
  1. app.py +330 -418
app.py CHANGED
@@ -1,431 +1,343 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- from threading import Thread
-
  import gradio as gr
  import spaces
  import torch
- import numpy as np
  from PIL import Image
- import cv2
-
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
- )
- from transformers.image_utils import load_image
-
- # Additional imports for new TTS
- from snac import SNAC
- from huggingface_hub import snapshot_download
- from dotenv import load_dotenv
- load_dotenv()
-
- # Set up device
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- tts_device = "cuda" if torch.cuda.is_available() else "cpu" # for SNAC and Orpheus TTS
-
- # Load DeepHermes Llama (chat/LLM) model
- hermes_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
- hermes_llm_tokenizer = AutoTokenizer.from_pretrained(hermes_model_id)
- hermes_llm_model = AutoModelForCausalLM.from_pretrained(
-     hermes_model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- hermes_llm_model.eval()
-
- # Load Qwen2-VL processor and model for multimodal tasks (e.g. video processing)
- MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_QWEN,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
-
- # Load Orpheus TTS model and SNAC for TTS synthesis
- print("Loading SNAC model...")
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
- snac_model = snac_model.to(tts_device)
-
- tts_model_name = "canopylabs/orpheus-3b-0.1-ft"
- # Download only model config and safetensors
- snapshot_download(
-     repo_id=tts_model_name,
-     allow_patterns=[
-         "config.json",
-         "*.safetensors",
-         "model.safetensors.index.json",
-     ],
-     ignore_patterns=[
-         "optimizer.pt",
-         "pytorch_model.bin",
-         "training_args.bin",
-         "scheduler.pt",
-         "tokenizer.json",
-         "tokenizer_config.json",
-         "special_tokens_map.json",
-         "vocab.json",
-         "merges.txt",
-         "tokenizer.*"
-     ]
- )
- orpheus_tts_model = AutoModelForCausalLM.from_pretrained(tts_model_name, torch_dtype=torch.bfloat16)
- orpheus_tts_model.to(tts_device)
- orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
- print(f"Orpheus TTS model loaded to {tts_device}")
-
- # Some global parameters for chat responses
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- # (Image generation related code has been fully removed.)
-
- MAX_SEED = np.iinfo(np.int32).max
-
- # Utility functions
- def save_image(img: Image.Image) -> str:
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- def progress_bar_html(label: str) -> str:
-     return f'''
- <div style="display: flex; align-items: center;">
-     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-     <div style="width: 110px; height: 5px; background-color: #FFA07A; border-radius: 2px; overflow: hidden;">
-         <div style="width: 100%; height: 100%; background-color: #FF4500; animation: loading 1.5s linear infinite;"></div>
-     </div>
- </div>
- <style>
- @keyframes loading {{
-     0% {{ transform: translateX(-100%); }}
-     100% {{ transform: translateX(100%); }}
- }}
- </style>
-     '''
-
- def downsample_video(video_path):
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- def clean_chat_history(chat_history):
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned
-
- # New TTS functions (SNAC/Orpheus pipeline)
- def process_prompt(prompt, voice, tokenizer, device):
-     prompt = f"{voice}: {prompt}"
-     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-     start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
-     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End markers
-     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
-     attention_mask = torch.ones_like(modified_input_ids)
-     return modified_input_ids.to(device), attention_mask.to(device)
-
- def parse_output(generated_ids):
-     token_to_find = 128257
-     token_to_remove = 128258
-     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
-     if len(token_indices[1]) > 0:
-         last_occurrence_idx = token_indices[1][-1].item()
-         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
-     else:
-         cropped_tensor = generated_ids
-     processed_rows = []
-     for row in cropped_tensor:
-         masked_row = row[row != token_to_remove]
-         processed_rows.append(masked_row)
-     code_lists = []
-     for row in processed_rows:
-         row_length = row.size(0)
-         new_length = (row_length // 7) * 7
-         trimmed_row = row[:new_length]
-         trimmed_row = [t - 128266 for t in trimmed_row]
-         code_lists.append(trimmed_row)
-     return code_lists[0]
-
- def redistribute_codes(code_list, snac_model):
-     device = next(snac_model.parameters()).device
-     layer_1 = []
-     layer_2 = []
-     layer_3 = []
-     for i in range((len(code_list)+1)//7):
-         layer_1.append(code_list[7*i])
-         layer_2.append(code_list[7*i+1]-4096)
-         layer_3.append(code_list[7*i+2]-(2*4096))
-         layer_3.append(code_list[7*i+3]-(3*4096))
-         layer_2.append(code_list[7*i+4]-(4*4096))
-         layer_3.append(code_list[7*i+5]-(5*4096))
-         layer_3.append(code_list[7*i+6]-(6*4096))
-     codes = [
-         torch.tensor(layer_1, device=device).unsqueeze(0),
-         torch.tensor(layer_2, device=device).unsqueeze(0),
-         torch.tensor(layer_3, device=device).unsqueeze(0)
-     ]
-     audio_hat = snac_model.decode(codes)
-     return audio_hat.detach().squeeze().cpu().numpy()

- def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens):
-     if not text.strip():
-         return None
      try:
-         # Removed in-function progress calls to maintain UI consistency.
-         input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
-         with torch.no_grad():
-             generated_ids = orpheus_tts_model.generate(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 max_new_tokens=max_new_tokens,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=top_p,
-                 repetition_penalty=repetition_penalty,
-                 num_return_sequences=1,
-                 eos_token_id=128258,
-             )
-         code_list = parse_output(generated_ids)
-         audio_samples = redistribute_codes(code_list, snac_model)
-         return (24000, audio_samples)
      except Exception as e:
-         print(f"Error generating speech: {e}")
-         return None

- # Main generate function for the chat interface
  @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     """
-     Generates chatbot responses with support for multimodal input, video processing,
-     TTS, and LLM-augmented TTS.
-
-     Trigger commands:
-     - "@video-infer": process video.
-     - "@<voice>-tts": directly convert text to speech.
-     - "@<voice>-llm": infer with the DeepHermes Llama model then convert to speech.
-     """
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-     lower_text = text.strip().lower()
-
-     # Branch for video processing.
-     if lower_text.startswith("@video-infer"):
-         prompt = text[len("@video-infer"):].strip()
-         if files:
-             video_path = files[0]
-             frames = downsample_video(video_path)
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
-             ]
-             for frame in frames:
-                 image, timestamp = frame
-                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                 image.save(image_path)
-                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-                 messages[1]["content"].append({"type": "image", "url": image_path})
-         else:
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
-             ]
-         inputs = processor.apply_chat_template(
-             messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
-         ).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             **inputs,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "temperature": temperature,
-             "top_p": top_p,
-             "top_k": top_k,
-             "repetition_penalty": repetition_penalty,
-         }
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing video with Qwen2VL")
-         for new_text in streamer:
-             buffer += new_text.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-         return
-
-     # Define TTS and LLM tag mappings.
-     tts_tags = {"@tara-tts": "tara", "@dan-tts": "dan", "@josh-tts": "josh", "@emma-tts": "emma"}
-     llm_tags = {"@tara-llm": "tara", "@dan-llm": "dan", "@josh-llm": "josh", "@emma-llm": "emma"}
-
-     # Branch for direct TTS (no LLM inference).
-     for tag, voice in tts_tags.items():
-         if lower_text.startswith(tag):
-             text = text[len(tag):].strip()
-             yield progress_bar_html("Processing with Orpheus")
-             audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens)
-             yield gr.Audio(audio_output, autoplay=True)
-             return
-
-     # Branch for LLM-augmented TTS.
-     for tag, voice in llm_tags.items():
-         if lower_text.startswith(tag):
-             text = text[len(tag):].strip()
-             conversation = [{"role": "user", "content": text}]
-             input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-             if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-                 input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             input_ids = input_ids.to(hermes_llm_model.device)
-             streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-             generation_kwargs = {
-                 "input_ids": input_ids,
-                 "streamer": streamer,
-                 "max_new_tokens": max_new_tokens,
-                 "do_sample": True,
-                 "top_p": top_p,
-                 "top_k": 50,
-                 "temperature": temperature,
-                 "num_beams": 1,
-                 "repetition_penalty": repetition_penalty,
-             }
-             t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
-             t.start()
-             outputs = []
-             for new_text in streamer:
-                 outputs.append(new_text)
-             final_response = "".join(outputs)
-             yield progress_bar_html("Processing with Orpheus")
-             audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max_new_tokens)
-             yield gr.Audio(audio_output, autoplay=True)
-             return
-
-     # Default branch for regular chat (text and multimodal without TTS).
-     conversation = clean_chat_history(chat_history)
-     conversation.append({"role": "user", "content": text})
-     # If files are provided, only non-image files (e.g. video) are processed via Qwen2VL.
-     if files:
-         # Process files using the processor (this branch no longer handles image generation)
-         if len(files) > 1:
-             inputs_list = [load_image(image) for image in files]
-         elif len(files) == 1:
-             inputs_list = [load_image(files[0])]
          else:
-             inputs_list = []
-         messages = [{
              "role": "user",
              "content": [
-                 *[{"type": "image", "image": img} for img in inputs_list],
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[prompt_full], images=inputs_list, return_tensors="pt", padding=True).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing with Qwen2VL")
-         for new_text in streamer:
-             buffer += new_text.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-     else:
-         input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(hermes_llm_model.device)
-         streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "top_p": top_p,
-             "top_k": top_k,
-             "temperature": temperature,
-             "num_beams": 1,
-             "repetition_penalty": repetition_penalty,
          }
-         t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
-         t.start()
-         outputs = []
-         yield progress_bar_html("Processing with DeepHermes LLM")
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)
-         final_response = "".join(outputs)
-         yield final_response
-
- # Gradio Interface
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         ["@josh-tts Hey! I’m Josh, [gasp] and wow, did I just surprise you with my realistic voice?"],
-         ["@dan-llm Explain the General Relativity theorem in short"],
-         ["@emma-tts Hey, I’m Emma, [sigh] and yes, I can talk just like a person… even when I’m tired."],
-         ["@josh-llm What causes rainbows to form?"],
-         ["@dan-tts Yo, I’m Dan, [groan] and yes, I can even sound annoyed if I have to."],
-         ["Write python program for array rotation"],
-         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-         ["@tara-tts Hey there, my name is Tara, [laugh] and I’m a speech generation model that can sound just like you!"],
-         ["@tara-llm Who is Nikola Tesla, and why did he die?"],
-         ["@emma-llm Explain the causes of rainbows"],
-         [{"text": "@video-infer Summarize the event in video", "files": ["examples/sky.mp4"]}],
-         [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
-     ],
-     cache_examples=False,
-     type="messages",
-     description="# **Orpheus Edge** `voice: tara, dan, emma, josh` \n `emotion: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>. Use @video-infer, orpheus: @<voice>-tts, or @<voice>-llm triggers llm response`",
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="‎ Use @tara-tts/@dan-tts for direct TTS or @tara-llm/@dan-llm for LLM+TTS, etc."),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )
-
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
  import gradio as gr
  import spaces
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+ from qwen_vl_utils import process_vision_info
  import torch
  from PIL import Image
+ import os
+ import uuid
+ import io
+ from threading import Thread
+ from reportlab.lib.pagesizes import A4
+ from reportlab.lib.styles import getSampleStyleSheet
+ from reportlab.lib import colors
+ from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
+ from reportlab.lib.units import inch
+ from reportlab.pdfbase import pdfmetrics
+ from reportlab.pdfbase.ttfonts import TTFont
+ import docx
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
+
+ # Define model options
+ MODEL_OPTIONS = {
+     "Qwen2VL Base": "Qwen/Qwen2-VL-2B-Instruct",
+     "Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
+     "Math Prase": "prithivMLmods/Qwen2-VL-Math-Prase-2B-Instruct",
+     "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
+ }
+
+ # Preload models and processors into CUDA
+ models = {}
+ processors = {}
+ for name, model_id in MODEL_OPTIONS.items():
+     print(f"Loading {name}...")
+     models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
+         model_id,
+         trust_remote_code=True,
+         torch_dtype=torch.float16
+     ).to("cuda").eval()
+     processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+ image_extensions = Image.registered_extensions()

+ def identify_and_save_blob(blob_path):
+     """Identifies if the blob is an image and saves it."""
      try:
+         with open(blob_path, 'rb') as file:
+             blob_content = file.read()
+             try:
+                 Image.open(io.BytesIO(blob_content)).verify() # Check if it's a valid image
+                 extension = ".png" # Default to PNG for saving
+                 media_type = "image"
+             except (IOError, SyntaxError):
+                 raise ValueError("Unsupported media type. Please upload a valid image.")
+
+             filename = f"temp_{uuid.uuid4()}_media{extension}"
+             with open(filename, "wb") as f:
+                 f.write(blob_content)
+
+             return filename, media_type
+
+     except FileNotFoundError:
+         raise ValueError(f"The file {blob_path} was not found.")
      except Exception as e:
+         raise ValueError(f"An error occurred while processing the file: {e}")

  @spaces.GPU
+ def qwen_inference(model_name, media_input, text_input=None):
+     """Handles inference for the selected model."""
+     model = models[model_name]
+     processor = processors[model_name]
+
+     if isinstance(media_input, str):
+         media_path = media_input
+         if media_path.endswith(tuple([i for i in image_extensions.keys()])):
+             media_type = "image"
          else:
+             try:
+                 media_path, media_type = identify_and_save_blob(media_input)
+             except Exception as e:
+                 raise ValueError("Unsupported media type. Please upload a valid image.")
+
+     messages = [
+         {
              "role": "user",
              "content": [
+                 {
+                     "type": media_type,
+                     media_type: media_path
+                 },
+                 {"type": "text", "text": text_input},
+             ],
          }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, _ = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to("cuda")
+
+     streamer = TextIteratorStreamer(
+         processor.tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         # Remove <|im_end|> or similar tokens from the output
+         buffer = buffer.replace("<|im_end|>", "")
+         yield buffer
+
+ def format_plain_text(output_text):
+     """Formats the output text as plain text without LaTeX delimiters."""
+     # Remove LaTeX delimiters and convert to plain text
+     plain_text = output_text.replace("\\(", "").replace("\\)", "").replace("\\[", "").replace("\\]", "")
+     return plain_text
+
+ def generate_document(media_path, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size):
+     """Generates a document with the input image and plain text output."""
+     plain_text = format_plain_text(output_text)
+     if file_format == "pdf":
+         return generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
+     elif file_format == "docx":
+         return generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
+
+ def generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
+     """Generates a PDF document."""
+     filename = f"output_{uuid.uuid4()}.pdf"
+     doc = SimpleDocTemplate(
+         filename,
+         pagesize=A4,
+         rightMargin=inch,
+         leftMargin=inch,
+         topMargin=inch,
+         bottomMargin=inch
+     )
+     styles = getSampleStyleSheet()
+     styles["Normal"].fontName = font_choice
+     styles["Normal"].fontSize = int(font_size)
+     styles["Normal"].leading = int(font_size) * line_spacing
+     styles["Normal"].alignment = {
+         "Left": 0,
+         "Center": 1,
+         "Right": 2,
+         "Justified": 4
+     }[alignment]
+
+     # Register font
+     font_path = f"font/{font_choice}"
+     pdfmetrics.registerFont(TTFont(font_choice, font_path))
+
+     story = []
+
+     # Add image with size adjustment
+     image_sizes = {
+         "Small": (200, 200),
+         "Medium": (400, 400),
+         "Large": (600, 600)
+     }
+     img = RLImage(media_path, width=image_sizes[image_size][0], height=image_sizes[image_size][1])
+     story.append(img)
+     story.append(Spacer(1, 12))
+
+     # Add plain text output
+     text = Paragraph(plain_text, styles["Normal"])
+     story.append(text)
+
+     doc.build(story)
+     return filename
+
+ def generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
+     """Generates a DOCX document."""
+     filename = f"output_{uuid.uuid4()}.docx"
+     doc = docx.Document()
+
+     # Add image with size adjustment
+     image_sizes = {
+         "Small": docx.shared.Inches(2),
+         "Medium": docx.shared.Inches(4),
+         "Large": docx.shared.Inches(6)
+     }
+     doc.add_picture(media_path, width=image_sizes[image_size])
+     doc.add_paragraph()
+
+     # Add plain text output
+     paragraph = doc.add_paragraph()
+     paragraph.paragraph_format.line_spacing = line_spacing
+     paragraph.paragraph_format.alignment = {
+         "Left": WD_ALIGN_PARAGRAPH.LEFT,
+         "Center": WD_ALIGN_PARAGRAPH.CENTER,
+         "Right": WD_ALIGN_PARAGRAPH.RIGHT,
+         "Justified": WD_ALIGN_PARAGRAPH.JUSTIFY
+     }[alignment]
+     run = paragraph.add_run(plain_text)
+     run.font.name = font_choice
+     run.font.size = docx.shared.Pt(int(font_size))
+
+     doc.save(filename)
+     return filename
+
+ # CSS for output styling
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ .submit-btn {
+     background-color: #cf3434 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #ff2323 !important;
+ }
+ .download-btn {
+     background-color: #35a6d6 !important;
+     color: white !important;
+ }
+ .download-btn:hover {
+     background-color: #22bcff !important;
+ }
+ """
+
+ # Gradio app setup
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("# Qwen2VL Models: Vision and Language Processing")
+
+     with gr.Tab(label="Image Input"):
+
+         with gr.Row():
+             with gr.Column():
+                 model_choice = gr.Dropdown(
+                     label="Model Selection",
+                     choices=list(MODEL_OPTIONS.keys()),
+                     value="Latex OCR"
+                 )
+                 input_media = gr.File(
+                     label="Upload Image", type="filepath"
+                 )
+                 text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
+                 submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
+
+             with gr.Column():
+                 output_text = gr.Textbox(label="Output Text", lines=10)
+                 plain_text_output = gr.Textbox(label="Standardized Plain Text", lines=10)
+
+         submit_btn.click(
+             qwen_inference, [model_choice, input_media, text_input], [output_text]
+         ).then(
+             lambda output_text: format_plain_text(output_text), [output_text], [plain_text_output]
+         )
+
+         # Add examples directly usable by clicking
+         with gr.Row():
+             gr.Examples(
+                 examples=[
+                     ["examples/1.png", "summarize the letter", "Text Analogy Ocrtest"],
+                     ["examples/2.jpg", "Summarize the full image in detail", "Latex OCR"],
+                     ["examples/3.png", "Describe the photo", "Qwen2VL Base"],
+                     ["examples/4.png", "summarize and solve the problem", "Math Prase"],
+                 ],
+                 inputs=[input_media, text_input, model_choice],
+                 outputs=[output_text, plain_text_output],
+                 fn=lambda img, question, model: qwen_inference(model, img, question),
+                 cache_examples=False,
+             )
+
+         with gr.Row():
+             with gr.Column():
+                 line_spacing = gr.Dropdown(
+                     choices=[0.5, 1.0, 1.15, 1.5, 2.0, 2.5, 3.0],
+                     value=1.5,
+                     label="Line Spacing"
+                 )
+                 font_size = gr.Dropdown(
+                     choices=["8", "10", "12", "14", "16", "18", "20", "22", "24"],
+                     value="18",
+                     label="Font Size"
+                 )
+                 font_choice = gr.Dropdown(
+                     choices=[
+                         "DejaVuMathTeXGyre.ttf",
+                         "FiraCode-Medium.ttf",
+                         "InputMono-Light.ttf",
+                         "JetBrainsMono-Thin.ttf",
+                         "ProggyCrossed Regular Mac.ttf",
+                         "SourceCodePro-Black.ttf",
+                         "arial.ttf",
+                         "calibri.ttf",
+                         "mukta-malar-extralight.ttf",
+                         "noto-sans-arabic-medium.ttf",
+                         "times new roman.ttf",
+                         "ANGSA.ttf",
+                         "Book-Antiqua.ttf",
+                         "CONSOLA.TTF",
+                         "COOPBL.TTF",
+                         "Rockwell-Bold.ttf",
+                         "Candara Light.TTF",
+                         "Carlito-Regular.ttf Carlito-Regular.ttf",
+                         "Castellar.ttf",
+                         "Courier New.ttf",
+                         "LSANS.TTF",
+                         "Lucida Bright Regular.ttf",
+                         "TRTempusSansITC.ttf",
+                         "Verdana.ttf",
+                         "bell-mt.ttf",
+                         "eras-itc-light.ttf",
+                         "fonnts.com-aptos-light.ttf",
+                         "georgia.ttf",
+                         "segoeuithis.ttf",
+                         "youyuan.TTF",
+                         "TfPonetoneExpanded-7BJZA.ttf",
+                     ],
+                     value="youyuan.TTF",
+                     label="Font Choice"
+                 )
+                 alignment = gr.Dropdown(
+                     choices=["Left", "Center", "Right", "Justified"],
+                     value="Justified",
+                     label="Text Alignment"
+                 )
+                 image_size = gr.Dropdown(
+                     choices=["Small", "Medium", "Large"],
+                     value="Small",
+                     label="Image Size"
+                 )
+                 file_format = gr.Radio(["pdf", "docx"], label="File Format", value="pdf")
+                 get_document_btn = gr.Button(value="Get Document", elem_classes="download-btn")
+
+         get_document_btn.click(
+             generate_document, [input_media, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size], gr.File(label="Download Document")
+         )
+
+ demo.launch(debug=True)
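For reference, a minimal standalone sketch of the streaming pattern that the new qwen_inference() follows, run outside the Gradio UI. It assumes a CUDA device, the "Latex OCR" checkpoint from MODEL_OPTIONS, and a local examples/1.png; the prompt text and path are placeholders, and this is an illustration rather than code from the commit.

import torch
from threading import Thread
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from qwen_vl_utils import process_vision_info

# Assumed checkpoint and image; swap in any entry from MODEL_OPTIONS and any local image.
model_id = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
image_path = "examples/1.png"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Same message layout the app builds in qwen_inference().
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image_path},
        {"type": "text", "text": "summarize the letter"},
    ],
}]

prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(messages)
inputs = processor(text=[prompt], images=image_inputs, padding=True, return_tensors="pt").to("cuda")

# Run generation on a worker thread and consume tokens as they stream.
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=1024)).start()

buffer = ""
for new_text in streamer:
    buffer += new_text
print(buffer.replace("<|im_end|>", ""))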