prithivMLmods committed on
Commit e354e80 · verified · 1 Parent(s): 35bb999

Update app.py

Files changed (1)
  1. app.py +492 -285
app.py CHANGED
@@ -1,302 +1,509 @@
- import subprocess
- subprocess.run(
-     'pip install flash-attn==2.7.0.post2 --no-build-isolation',
-     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-     shell=True
- )
- subprocess.run(
-     'pip install transformers',
-     shell=True
- )
-
-
- import spaces
  import os
- import re
- import logging
- from typing import List
  from threading import Thread
- import base64

- import torch
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

- # ----------------------------------------------------------------------
- # 1. Setup Model & Tokenizer
- # ----------------------------------------------------------------------
- model_name = 'prithivMLmods/Raptor-X6' # Change as needed
- use_thread = True # Generation happens in a background thread

  model = AutoModelForCausalLM.from_pretrained(
-     model_name,
      torch_dtype=torch.bfloat16,
-     trust_remote_code=True
- ).to("cuda")
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
- logging.getLogger("httpx").setLevel(logging.WARNING)
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # ----------------------------------------------------------------------
- # 2. Two-Phase Prompt Templates
- # ----------------------------------------------------------------------
- s1_inference_prompt_think_only = """<|im_start|>user
- {question}<|im_end|>
- <|im_start|>assistant
- <|im_start|>think
- """
-
- # ----------------------------------------------------------------------
- # 3. Generation Parameter Setup
- # ----------------------------------------------------------------------
- THINK_MAX_NEW_TOKENS = 12000
- ANSWER_MAX_NEW_TOKENS = 12000
-
- def initialize_gen_kwargs():
-     return {
-         "max_new_tokens": 1024, # default; will be overwritten per phase
-         "do_sample": True,
-         "temperature": 0.7,
-         "top_p": 0.9,
-         "repetition_penalty": 1.05,
-         # "eos_token_id": model.generation_config.eos_token_id, # Removed to avoid premature stopping
-         "pad_token_id": tokenizer.pad_token_id,
-         "use_cache": True,
-         "streamer": None # dynamically added
-     }
-
- # ----------------------------------------------------------------------
- # 4. Helper to submit chat
- # ----------------------------------------------------------------------
- def submit_chat(chatbot, text_input):
-     if not text_input.strip():
-         return chatbot, ""
-     response = ""
-     chatbot.append((text_input, response))
-     return chatbot, ""
-
- # ----------------------------------------------------------------------
- # 5. Artifacts Handling
- # We parse code from the final answer and display it in an iframe
- # ----------------------------------------------------------------------
- def extract_html_code_block(text: str) -> str:
      """
-     Look for a ```html ... ``` block in the text.
-     If found, return only that block content.
-     Otherwise, return the entire text.
      """
-     pattern = r'```html\s*(.*?)\s*```'
-     match = re.search(pattern, text, re.DOTALL)
-     if match:
-         return match.group(1).strip()
-     else:
-         return text.strip()
-
- def send_to_sandbox(html_code: str) -> str:
      """
-     Convert the code to a data URI iframe so it can be rendered
-     inside Gradio HTML component.
      """
-     encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
-     data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
-     return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
-
- # ----------------------------------------------------------------------
- # 6. The Two-Phase Streaming Inference
- # - Phase 1: "think" (chain-of-thought)
- # - Phase 2: "answer"
- # ----------------------------------------------------------------------
- @spaces.GPU
- def ovis_chat(chatbot: List[List[str]]):
-     # Phase 1: chain-of-thought
-     last_query = chatbot[-1][0]
-     formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
-     input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
-     attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
-
-     think_inputs = {
-         "input_ids": input_ids_think,
-         "attention_mask": attention_mask_think
-     }
-     gen_kwargs_think = initialize_gen_kwargs()
-     gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
-     think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     gen_kwargs_think["streamer"] = think_streamer
-
-     full_think = ""
-     with torch.inference_mode():
-         thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
-         thread_think.start()
-         for new_text in think_streamer:
-             full_think += new_text
-             display_text = f"<|im_start|>think\n{full_think.strip()}"
-             chatbot[-1][1] = display_text
-             yield chatbot, "" # second return is artifact placeholder
-     thread_think.join()
-
-     # Phase 2: answer
-     new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
-     input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
-     attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
-
-     answer_inputs = {
-         "input_ids": input_ids_answer,
-         "attention_mask": attention_mask_answer
      }
-     gen_kwargs_answer = initialize_gen_kwargs()
-     gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
-     answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     gen_kwargs_answer["streamer"] = answer_streamer
-
-     full_answer = ""
-     with torch.inference_mode():
-         thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
-         thread_answer.start()
-         for new_text in answer_streamer:
-             full_answer += new_text
-             display_text = (
-                 f"<|im_start|>think\n{full_think.strip()}\n\n"
-                 f"<|im_start|>answer\n{full_answer.strip()}"
-             )
-             chatbot[-1][1] = display_text
-             yield chatbot, ""
-     thread_answer.join()
-
-     log_conversation(chatbot)
-
-     # Once final answer is complete, parse out HTML code block and
-     # return it as an artifact (iframe).
-     html_code = extract_html_code_block(full_answer)
-     sandbox_iframe = send_to_sandbox(html_code)
-     yield chatbot, sandbox_iframe
-
- # ----------------------------------------------------------------------
- # 7. Logging and Clearing
- # ----------------------------------------------------------------------
- def log_conversation(chatbot: List[List[str]]):
-     logger.info("[CONVERSATION]")
-     for i, (query, response) in enumerate(chatbot, 1):
-         logger.info(f"Q{i}: {query}\nA{i}: {response}")
-
- def clear_chat():
-     return [], "", ""
-
- # ----------------------------------------------------------------------
- # 8. Gradio UI Setup
- # ----------------------------------------------------------------------
- css_code = """
- .left_header {
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
-     align-items: center;
- }
-
- .right_panel {
-     margin-top: 16px;
-     border: 1px solid #BFBFC4;
-     border-radius: 8px;
-     overflow: hidden;
- }
-
- .render_header {
-     height: 30px;
-     width: 100%;
-     padding: 5px 16px;
-     background-color: #f5f5f5;
- }
-
- .header_btn {
-     display: inline-block;
-     height: 10px;
-     width: 10px;
-     border-radius: 50%;
-     margin-right: 4px;
- }
-
- .render_header > .header_btn:nth-child(1) {
-     background-color: #f5222d;
- }
-
- .render_header > .header_btn:nth-child(2) {
-     background-color: #faad14;
- }
- .render_header > .header_btn:nth-child(3) {
-     background-color: #52c41a;
- }
-
- .right_content {
-     height: 920px;
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
-     align-items: center;
- }

- .html_content {
-     width: 100%;
-     height: 920px;
- }
- """
-
- svg_content = """
- <svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
-     <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
-     <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
-     <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
-     <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
- </svg>
- """
-
- with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
-     gr.HTML(f"""
-     <div class="left_header" style="margin-bottom: 20px;">
-         {svg_content}
-         <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
-         <p>(Two-phase chain-of-thought with artifact extraction)</p>
-     </div>
-     """)

-     with gr.Row():
-         with gr.Column(scale=4):
-             chatbot = gr.Chatbot(
-                 label="Chat",
-                 height=520,
-                 show_copy_button=True
-             )
-             with gr.Row():
-                 text_input = gr.Textbox(
-                     label="Prompt",
-                     placeholder="Enter your query...",
-                     lines=1
-                 )
-             with gr.Row():
-                 submit_btn = gr.Button("Send", variant="primary")
-                 clear_btn = gr.Button("Clear", variant="secondary")
-         with gr.Column(scale=6):
-             gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
-             artifact_html = gr.HTML(
-                 value="",
-                 elem_classes="html_content"
-             )
-
-     submit_btn.click(
-         submit_chat, [chatbot, text_input], [chatbot, text_input]
-     ).then(
-         ovis_chat, [chatbot], [chatbot, artifact_html]
-     )
-
-     text_input.submit(
-         submit_chat, [chatbot, text_input], [chatbot, text_input]
-     ).then(
-         ovis_chat, [chatbot], [chatbot, artifact_html]
-     )
-
-     clear_btn.click(
-         clear_chat,
-         outputs=[chatbot, text_input, artifact_html]
      )

- demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
  import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
  from threading import Thread

  import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image
+ import edge_tts
+ import cv2
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
  model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
      torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+ # Updated TTS voices list (all voices)
+ TTS_VOICES = [
+     "af-ZA-AdriNeural",
+     "af-ZA-WillemNeural",
+     "am-ET-AmehaNeural",
+     "am-ET-MekdesNeural",
+     "ar-AE-FatimaNeural",
+     "ar-AE-HamdanNeural",
+     "ar-BH-LailaNeural",
+     "ar-BH-MajedNeural",
+     "ar-DZ-AminaNeural",
+     "ar-DZ-IsmaelNeural",
+     "ar-EG-SalmaNeural",
+     "ar-EG-OmarNeural",
+     "ar-IQ-LanaNeural",
+     "ar-IQ-BassamNeural",
+     "ar-JO-SanaNeural",
+     "ar-JO-TaimNeural",
+     "ar-KW-NouraNeural",
+     "ar-KW-FahedNeural",
+     "ar-LB-LaylaNeural",
+     "ar-LB-RamiNeural",
+     "ar-LY-ImanNeural",
+     "ar-LY-OmarNeural",
+     "ar-MA-MounaNeural",
+     "ar-MA-JamalNeural",
+     "ar-OM-AyshaNeural",
+     "ar-OM-AbdullahNeural",
+     "ar-QA-AmalNeural",
+     "ar-QA-MoazNeural",
+     "ar-SA-ZariyahNeural",
+     "ar-SA-HamedNeural",
+     "ar-SY-AmanyNeural",
+     "ar-SY-LaithNeural",
+     "ar-TN-ReemNeural",
+     "ar-TN-SeifNeural",
+     "ar-YE-MaryamNeural",
+     "ar-YE-SalehNeural",
+     "az-AZ-BabekNeural",
+     "az-AZ-BanuNeural",
+     "bg-BG-BorislavNeural",
+     "bg-BG-KalinaNeural",
+     "bn-BD-NabanitaNeural",
+     "bn-BD-PradeepNeural",
+     "bn-IN-TanishaNeural",
+     "bn-IN-SwapanNeural",
+     "bs-BA-GoranNeural",
+     "bs-BA-VesnaNeural",
+     "ca-ES-JoanaNeural",
+     "ca-ES-AlbaNeural",
+     "ca-ES-EnricNeural",
+     "cs-CZ-AntoninNeural",
+     "cs-CZ-VlastaNeural",
+     "cy-GB-NiaNeural",
+     "cy-GB-AledNeural",
+     "da-DK-ChristelNeural",
+     "da-DK-JeppeNeural",
+     "de-AT-IngridNeural",
+     "de-AT-JonasNeural",
+     "de-CH-LeniNeural",
+     "de-CH-JanNeural",
+     "de-DE-KatjaNeural",
+     "de-DE-ConradNeural",
+     "el-GR-AthinaNeural",
+     "el-GR-NestorasNeural",
+     "en-AU-AnnetteNeural",
+     "en-AU-MichaelNeural",
+     "en-CA-ClaraNeural",
+     "en-CA-LiamNeural",
+     "en-GB-SoniaNeural",
+     "en-GB-RyanNeural",
+     "en-GH-EsiNeural",
+     "en-GH-KwameNeural",
+     "en-HK-YanNeural",
+     "en-HK-TrevorNeural",
+     "en-IE-EmilyNeural",
+     "en-IE-ConnorNeural",
+     "en-IN-NeerjaNeural",
+     "en-IN-PrabhasNeural",
+     "en-KE-ChantelleNeural",
+     "en-KE-ChilembaNeural",
+     "en-NG-EzinneNeural",
+     "en-NG-AbechiNeural",
+     "en-NZ-MollyNeural",
+     "en-NZ-MitchellNeural",
+     "en-PH-RosaNeural",
+     "en-PH-JamesNeural",
+     "en-SG-LunaNeural",
+     "en-SG-WayneNeural",
+     "en-TZ-ImaniNeural",
+     "en-TZ-DaudiNeural",
+     "en-US-JennyNeural",
+     "en-US-GuyNeural",
+     "en-ZA-LeahNeural",
+     "en-ZA-LukeNeural",
+     "es-AR-ElenaNeural",
+     "es-AR-TomasNeural",
+     "es-BO-SofiaNeural",
+     "es-BO-MarceloNeural",
+     "es-CL-CatalinaNeural",
+     "es-CL-LorenzoNeural",
+     "es-CO-SalomeNeural",
+     "es-CO-GonzaloNeural",
+     "es-CR-MariaNeural",
+     "es-CR-JuanNeural",
+     "es-CU-BelkysNeural",
+     "es-CU-ManuelNeural",
+     "es-DO-RamonaNeural",
+     "es-DO-EmilioNeural",
+     "es-EC-AndreaNeural",
+     "es-EC-LuisNeural",
+     "es-ES-ElviraNeural",
+     "es-ES-AlvaroNeural",
+     "es-GQ-TeresaNeural",
+     "es-GQ-JavierNeural",
+     "es-GT-MartaNeural",
+     "es-GT-AndresNeural",
+     "es-HN-KarlaNeural",
+     "es-HN-CarlosNeural",
+     "es-MX-DaliaNeural",
+     "es-MX-JorgeNeural",
+     "es-NI-YolandaNeural",
+     "es-NI-FedericoNeural",
+     "es-PA-MargaritaNeural",
+     "es-PA-RobertoNeural",
+     "es-PE-CamilaNeural",
+     "es-PE-AlexNeural",
+     "es-PR-KarinaNeural",
+     "es-PR-VictorNeural",
+     "es-PY-TaniaNeural",
+     "es-PY-MarioNeural",
+     "es-SV-LorenaNeural",
+     "es-SV-RodrigoNeural",
+     "es-US-SaraNeural",
+     "es-US-AlonsoNeural",
+     "es-UY-ValentinaNeural",
+     "es-UY-MateoNeural",
+     "es-VE-PaolaNeural",
+     "es-VE-SebastianNeural",
+     "et-EE-AnuNeural",
+     "et-EE-KertNeural",
+     "eu-ES-AinhoaNeural",
+     "eu-ES-AnderNeural",
+     "fa-IR-DilaraNeural",
+     "fa-IR-FaridNeural",
+     "fi-FI-NooraNeural",
+     "fi-FI-HarriNeural",
+     "fil-PH-BlessicaNeural",
+     "fil-PH-AngeloNeural",
+     "fr-BE-CharlineNeural",
+     "fr-BE-GerardNeural",
+     "fr-CA-SylvieNeural",
+     "fr-CA-AntoineNeural",
+     "fr-CH-ArianeNeural",
+     "fr-CH-GuillaumeNeural",
+     "fr-FR-DeniseNeural",
+     "fr-FR-HenriNeural",
+     "ga-IE-OrlaNeural",
+     "ga-IE-ColmNeural",
+     "gl-ES-SoniaNeural",
+     "gl-ES-XiaoqiangNeural",
+     "gu-IN-DhwaniNeural",
+     "gu-IN-NiranjanNeural",
+     "ha-NG-AishaNeural",
+     "ha-NG-YusufNeural",
+     "he-IL-HilaNeural",
+     "he-IL-AvriNeural",
+     "hi-IN-SwaraNeural",
+     "hi-IN-MadhurNeural",
+     "hr-HR-GabrijelaNeural",
+     "hr-HR-SreckoNeural",
+     "hu-HU-NoemiNeural",
+     "hu-HU-TamasNeural",
+     "hy-AM-AnushNeural",
+     "hy-AM-HaykNeural",
+     "id-ID-ArdiNeural",
+     "id-ID-GadisNeural",
+     "ig-NG-AdaNeural",
+     "ig-NG-EzeNeural",
+     "is-IS-GudrunNeural",
+     "is-IS-GunnarNeural",
+     "it-IT-ElsaNeural",
+     "it-IT-DiegoNeural",
+     "ja-JP-NanamiNeural",
+     "ja-JP-KeitaNeural",
+     "jv-ID-DianNeural",
+     "jv-ID-GustiNeural",
+     "ka-GE-EkaNeural",
+     # ... (truncated for brevity; include all voices as needed)
+ ]
+
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """Convert text to speech using Edge TTS and save as MP3"""
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file
+
+ def clean_chat_history(chat_history):
      """
+     Filter out any chat entries whose "content" is not a string.
+     This helps prevent errors when concatenating previous messages.
      """
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned
+
+ # Environment variables and parameters for Stable Diffusion XL (left in case needed in the future)
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
+
+ # Load the SDXL pipeline (not used in the current configuration)
+ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID_SD,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     use_safetensors=True,
+     add_watermarker=False,
+ ).to(device)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+ if torch.cuda.is_available():
+     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+ if USE_TORCH_COMPILE:
+     sd_pipe.compile()
+ if ENABLE_CPU_OFFLOAD:
+     sd_pipe.enable_model_cpu_offload()
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ def save_image(img: Image.Image) -> str:
+     """Save a PIL image with a unique filename and return the path."""
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def progress_bar_html(label: str) -> str:
      """
+     Returns an HTML snippet for a thin progress bar with a label.
+     The progress bar is styled as a dark red animated bar.
      """
+     return f'''
+ <div style="display: flex; align-items: center;">
+     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+     <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+         <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+     </div>
+ </div>
+ <style>
+ @keyframes loading {{
+     0% {{ transform: translateX(-100%); }}
+     100% {{ transform: translateX(100%); }}
+ }}
+ </style>
+ '''
+
+ def downsample_video(video_path):
+     """
+     Downsamples the video to 10 evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ @spaces.GPU(duration=60, enable_queue=True)
+ def generate_image_fn(
+     prompt: str,
+     negative_prompt: str = "",
+     use_negative_prompt: bool = False,
+     seed: int = 1,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale: float = 3,
+     num_inference_steps: int = 25,
+     randomize_seed: bool = False,
+     use_resolution_binning: bool = True,
+     num_images: int = 1,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """(Image generation function is preserved but not called in the current configuration)"""
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
+     options = {
+         "prompt": [prompt] * num_images,
+         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+         "width": width,
+         "height": height,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "generator": generator,
+         "output_type": "pil",
      }
+     if use_resolution_binning:
+         options["use_resolution_binning"] = True
+     images = []
+     for i in range(0, num_images, BATCH_SIZE):
+         batch_options = options.copy()
+         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+         if device.type == "cuda":
+             with torch.autocast("cuda", dtype=torch.float16):
+                 outputs = sd_pipe(**batch_options)
+         else:
+             outputs = sd_pipe(**batch_options)
+         images.extend(outputs.images)
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed

+ @spaces.GPU
+ def generate(
+     input_dict: dict,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+     convert_to_speech: bool = False,
+     tts_rate: float = 1.0,
+     tts_voice: str = "en-US-JennyNeural",
+ ):
+     """
+     Generates chatbot responses with support for multimodal input and TTS conversion.
+     When files (images or videos) are provided, Qwen2VL is used.
+     Otherwise, the FastThink-0.5B text model is used.
+     After generating the response, if convert_to_speech is True the text is passed to the TTS function.
+     """
+     text = input_dict["text"].strip()
+     files = input_dict.get("files", [])

+     # Determine which branch to use: multimodal (if files provided) or text-only.
+     if files:
+         # Process uploaded files as images (or videos)
+         if len(files) > 1:
+             images = [load_image(image) for image in files]
+         else:
+             images = [load_image(files[0])]
+         messages = [{
+             "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ]
+         }]
+         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing multimodal input...")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         final_response = buffer
+     else:
+         conversation = clean_chat_history(chat_history)
+         conversation.append({"role": "user", "content": text})
+         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+         input_ids = input_ids.to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "top_p": top_p,
+             "top_k": top_k,
+             "temperature": temperature,
+             "num_beams": 1,
+             "repetition_penalty": repetition_penalty,
+         }
+         t = Thread(target=model.generate, kwargs=generation_kwargs)
+         t.start()
+         outputs = []
+         yield progress_bar_html("Processing text...")
+         for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+         final_response = "".join(outputs)
+
+     # Yield the final text response.
+     yield final_response
+
+     # If TTS conversion is enabled, log the message and generate speech.
+     if convert_to_speech:
+         print("Generate Response to Generate Speech")
+         # Here tts_rate can be used to adjust parameters if needed.
+         output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+         yield gr.Audio(output_file, autoplay=True)
+
+ with gr.Blocks() as demo:
+     with gr.Sidebar():
+         gr.Markdown("# TTS Conversion")
+         tts_rate_slider = gr.Slider(label="TTS Rate", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
+         tts_voice_radio = gr.Radio(choices=TTS_VOICES, label="Choose TTS Voice", value="en-US-JennyNeural")
+         convert_to_speech_checkbox = gr.Checkbox(label="Convert to Speech", value=False)
+
+     chat_interface = gr.ChatInterface(
+         fn=generate,
+         additional_inputs=[
+             gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+             gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+             gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+             gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+             gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+             # Pass TTS parameters to the generate function.
+             convert_to_speech_checkbox,
+             tts_rate_slider,
+             tts_voice_radio,
+         ],
+         examples=[
+             ["Write the Python Program for Array Rotation"],
+             [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
+             [{"text": "Describe the Ad", "files": ["examples/coca.mp4"]}],
+             [{"text": "Summarize the event in video", "files": ["examples/sky.mp4"]}],
+             [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
+             ["Who is Nikola Tesla, and why did he die?"],
+             [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+             ["What causes rainbows to form?"],
+         ],
+         cache_examples=False,
+         type="messages",
+         description="# **QwQ Edge: Multimodal (image upload uses Qwen2-VL) with TTS conversion**",
+         fill_height=True,
+         textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Enter text or upload files"),
+         stop_btn="Stop Generation",
+         multimodal=True,
      )

+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)