Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -1,302 +1,509 @@
-import subprocess
-subprocess.run(
-    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-subprocess.run(
-    'pip install transformers',
-    shell=True
-)
-
-import spaces
 import os
 from threading import Thread
-import base64
 
-import torch
 import gradio as gr
-
-model_name = 'prithivMLmods/Raptor-X6'  # Change as needed
-use_thread = True  # Generation happens in a background thread
 
 model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
 """
-    Otherwise, return the entire text.
 """
-        yield chatbot, ""
-    thread_answer.join()
-
-    log_conversation(chatbot)
-
-    # Once final answer is complete, parse out HTML code block and
-    # return it as an artifact (iframe).
-    html_code = extract_html_code_block(full_answer)
-    sandbox_iframe = send_to_sandbox(html_code)
-    yield chatbot, sandbox_iframe
-
-# ----------------------------------------------------------------------
-# 7. Logging and Clearing
-# ----------------------------------------------------------------------
-def log_conversation(chatbot: List[List[str]]):
-    logger.info("[CONVERSATION]")
-    for i, (query, response) in enumerate(chatbot, 1):
-        logger.info(f"Q{i}: {query}\nA{i}: {response}")
-
-def clear_chat():
-    return [], "", ""
-
-# ----------------------------------------------------------------------
-# 8. Gradio UI Setup
-# ----------------------------------------------------------------------
-css_code = """
-.left_header {
-    display: flex;
-    flex-direction: column;
-    justify-content: center;
-    align-items: center;
-}
-
-.right_panel {
-    margin-top: 16px;
-    border: 1px solid #BFBFC4;
-    border-radius: 8px;
-    overflow: hidden;
-}
-
-.render_header {
-    height: 30px;
-    width: 100%;
-    padding: 5px 16px;
-    background-color: #f5f5f5;
-}
-
-.header_btn {
-    display: inline-block;
-    height: 10px;
-    width: 10px;
-    border-radius: 50%;
-    margin-right: 4px;
-}
-
-.render_header > .header_btn:nth-child(1) {
-    background-color: #f5222d;
-}
-
-.render_header > .header_btn:nth-child(2) {
-    background-color: #faad14;
-}
-.render_header > .header_btn:nth-child(3) {
-    background-color: #52c41a;
-}
-
-.right_content {
-    height: 920px;
-    display: flex;
-    flex-direction: column;
-    justify-content: center;
-    align-items: center;
-}
-"""
-
-    </div>
-    """)
-
-)
 )
 import os
+import random
+import uuid
+import json
+import time
+import asyncio
 from threading import Thread
 
 import gradio as gr
+import spaces
+import torch
+import numpy as np
+from PIL import Image
+import edge_tts
+import cv2
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
+)
+from transformers.image_utils import load_image
+from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Load text-only model and tokenizer
+model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
     torch_dtype=torch.bfloat16,
+)
+model.eval()
+
+# Updated TTS voices list (all voices)
+TTS_VOICES = [
+    "af-ZA-AdriNeural",
+    "af-ZA-WillemNeural",
+    "am-ET-AmehaNeural",
+    "am-ET-MekdesNeural",
+    "ar-AE-FatimaNeural",
+    "ar-AE-HamdanNeural",
+    "ar-BH-LailaNeural",
+    "ar-BH-MajedNeural",
+    "ar-DZ-AminaNeural",
+    "ar-DZ-IsmaelNeural",
+    "ar-EG-SalmaNeural",
+    "ar-EG-OmarNeural",
+    "ar-IQ-LanaNeural",
+    "ar-IQ-BassamNeural",
+    "ar-JO-SanaNeural",
+    "ar-JO-TaimNeural",
+    "ar-KW-NouraNeural",
+    "ar-KW-FahedNeural",
+    "ar-LB-LaylaNeural",
+    "ar-LB-RamiNeural",
+    "ar-LY-ImanNeural",
+    "ar-LY-OmarNeural",
+    "ar-MA-MounaNeural",
+    "ar-MA-JamalNeural",
+    "ar-OM-AyshaNeural",
+    "ar-OM-AbdullahNeural",
+    "ar-QA-AmalNeural",
+    "ar-QA-MoazNeural",
+    "ar-SA-ZariyahNeural",
+    "ar-SA-HamedNeural",
+    "ar-SY-AmanyNeural",
+    "ar-SY-LaithNeural",
+    "ar-TN-ReemNeural",
+    "ar-TN-SeifNeural",
+    "ar-YE-MaryamNeural",
+    "ar-YE-SalehNeural",
+    "az-AZ-BabekNeural",
+    "az-AZ-BanuNeural",
+    "bg-BG-BorislavNeural",
+    "bg-BG-KalinaNeural",
+    "bn-BD-NabanitaNeural",
+    "bn-BD-PradeepNeural",
+    "bn-IN-TanishaNeural",
+    "bn-IN-SwapanNeural",
+    "bs-BA-GoranNeural",
+    "bs-BA-VesnaNeural",
+    "ca-ES-JoanaNeural",
+    "ca-ES-AlbaNeural",
+    "ca-ES-EnricNeural",
+    "cs-CZ-AntoninNeural",
+    "cs-CZ-VlastaNeural",
+    "cy-GB-NiaNeural",
+    "cy-GB-AledNeural",
+    "da-DK-ChristelNeural",
+    "da-DK-JeppeNeural",
+    "de-AT-IngridNeural",
+    "de-AT-JonasNeural",
+    "de-CH-LeniNeural",
+    "de-CH-JanNeural",
+    "de-DE-KatjaNeural",
+    "de-DE-ConradNeural",
+    "el-GR-AthinaNeural",
+    "el-GR-NestorasNeural",
+    "en-AU-AnnetteNeural",
+    "en-AU-MichaelNeural",
+    "en-CA-ClaraNeural",
+    "en-CA-LiamNeural",
+    "en-GB-SoniaNeural",
+    "en-GB-RyanNeural",
+    "en-GH-EsiNeural",
+    "en-GH-KwameNeural",
+    "en-HK-YanNeural",
+    "en-HK-TrevorNeural",
+    "en-IE-EmilyNeural",
+    "en-IE-ConnorNeural",
+    "en-IN-NeerjaNeural",
+    "en-IN-PrabhasNeural",
+    "en-KE-ChantelleNeural",
+    "en-KE-ChilembaNeural",
+    "en-NG-EzinneNeural",
+    "en-NG-AbechiNeural",
+    "en-NZ-MollyNeural",
+    "en-NZ-MitchellNeural",
+    "en-PH-RosaNeural",
+    "en-PH-JamesNeural",
+    "en-SG-LunaNeural",
+    "en-SG-WayneNeural",
+    "en-TZ-ImaniNeural",
+    "en-TZ-DaudiNeural",
+    "en-US-JennyNeural",
+    "en-US-GuyNeural",
+    "en-ZA-LeahNeural",
+    "en-ZA-LukeNeural",
+    "es-AR-ElenaNeural",
+    "es-AR-TomasNeural",
+    "es-BO-SofiaNeural",
+    "es-BO-MarceloNeural",
+    "es-CL-CatalinaNeural",
+    "es-CL-LorenzoNeural",
+    "es-CO-SalomeNeural",
+    "es-CO-GonzaloNeural",
+    "es-CR-MariaNeural",
+    "es-CR-JuanNeural",
+    "es-CU-BelkysNeural",
+    "es-CU-ManuelNeural",
+    "es-DO-RamonaNeural",
+    "es-DO-EmilioNeural",
+    "es-EC-AndreaNeural",
+    "es-EC-LuisNeural",
+    "es-ES-ElviraNeural",
+    "es-ES-AlvaroNeural",
+    "es-GQ-TeresaNeural",
+    "es-GQ-JavierNeural",
+    "es-GT-MartaNeural",
+    "es-GT-AndresNeural",
+    "es-HN-KarlaNeural",
+    "es-HN-CarlosNeural",
+    "es-MX-DaliaNeural",
+    "es-MX-JorgeNeural",
+    "es-NI-YolandaNeural",
+    "es-NI-FedericoNeural",
+    "es-PA-MargaritaNeural",
+    "es-PA-RobertoNeural",
+    "es-PE-CamilaNeural",
+    "es-PE-AlexNeural",
+    "es-PR-KarinaNeural",
+    "es-PR-VictorNeural",
+    "es-PY-TaniaNeural",
+    "es-PY-MarioNeural",
+    "es-SV-LorenaNeural",
+    "es-SV-RodrigoNeural",
+    "es-US-SaraNeural",
+    "es-US-AlonsoNeural",
+    "es-UY-ValentinaNeural",
+    "es-UY-MateoNeural",
+    "es-VE-PaolaNeural",
+    "es-VE-SebastianNeural",
+    "et-EE-AnuNeural",
+    "et-EE-KertNeural",
+    "eu-ES-AinhoaNeural",
+    "eu-ES-AnderNeural",
+    "fa-IR-DilaraNeural",
+    "fa-IR-FaridNeural",
+    "fi-FI-NooraNeural",
+    "fi-FI-HarriNeural",
+    "fil-PH-BlessicaNeural",
+    "fil-PH-AngeloNeural",
+    "fr-BE-CharlineNeural",
+    "fr-BE-GerardNeural",
+    "fr-CA-SylvieNeural",
+    "fr-CA-AntoineNeural",
+    "fr-CH-ArianeNeural",
+    "fr-CH-GuillaumeNeural",
+    "fr-FR-DeniseNeural",
+    "fr-FR-HenriNeural",
+    "ga-IE-OrlaNeural",
+    "ga-IE-ColmNeural",
+    "gl-ES-SoniaNeural",
+    "gl-ES-XiaoqiangNeural",
+    "gu-IN-DhwaniNeural",
+    "gu-IN-NiranjanNeural",
+    "ha-NG-AishaNeural",
+    "ha-NG-YusufNeural",
+    "he-IL-HilaNeural",
+    "he-IL-AvriNeural",
+    "hi-IN-SwaraNeural",
+    "hi-IN-MadhurNeural",
+    "hr-HR-GabrijelaNeural",
+    "hr-HR-SreckoNeural",
+    "hu-HU-NoemiNeural",
+    "hu-HU-TamasNeural",
+    "hy-AM-AnushNeural",
+    "hy-AM-HaykNeural",
+    "id-ID-ArdiNeural",
+    "id-ID-GadisNeural",
+    "ig-NG-AdaNeural",
+    "ig-NG-EzeNeural",
+    "is-IS-GudrunNeural",
+    "is-IS-GunnarNeural",
+    "it-IT-ElsaNeural",
+    "it-IT-DiegoNeural",
+    "ja-JP-NanamiNeural",
+    "ja-JP-KeitaNeural",
+    "jv-ID-DianNeural",
+    "jv-ID-GustiNeural",
+    "ka-GE-EkaNeural",
+    # ... (truncated for brevity; include all voices as needed)
+]
+
+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
+
+async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+    """Convert text to speech using Edge TTS and save as MP3"""
+    communicate = edge_tts.Communicate(text, voice)
+    await communicate.save(output_file)
+    return output_file
+
+def clean_chat_history(chat_history):
     """
+    Filter out any chat entries whose "content" is not a string.
+    This helps prevent errors when concatenating previous messages.
     """
+    cleaned = []
+    for msg in chat_history:
+        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+            cleaned.append(msg)
+    return cleaned
+
+# Environment variables and parameters for Stable Diffusion XL (left in case needed in the future)
+MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
+
+# Load the SDXL pipeline (not used in the current configuration)
+sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+    MODEL_ID_SD,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    use_safetensors=True,
+    add_watermarker=False,
+).to(device)
+sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+if torch.cuda.is_available():
+    sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+if USE_TORCH_COMPILE:
+    sd_pipe.compile()
+if ENABLE_CPU_OFFLOAD:
+    sd_pipe.enable_model_cpu_offload()
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def save_image(img: Image.Image) -> str:
+    """Save a PIL image with a unique filename and return the path."""
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+def progress_bar_html(label: str) -> str:
    """
+    Returns an HTML snippet for a thin progress bar with a label.
+    The progress bar is styled as a dark red animated bar.
    """
+    return f'''
+    <div style="display: flex; align-items: center;">
+        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+        <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+        </div>
+    </div>
+    <style>
+    @keyframes loading {{
+        0% {{ transform: translateX(-100%); }}
+        100% {{ transform: translateX(100%); }}
+    }}
+    </style>
+    '''
+
+def downsample_video(video_path):
+    """
+    Downsamples the video to 10 evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
+
+@spaces.GPU(duration=60, enable_queue=True)
+def generate_image_fn(
+    prompt: str,
+    negative_prompt: str = "",
+    use_negative_prompt: bool = False,
+    seed: int = 1,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    num_inference_steps: int = 25,
+    randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
+    num_images: int = 1,
+    progress=gr.Progress(track_tqdm=True),
+):
+    """(Image generation function is preserved but not called in the current configuration)"""
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    generator = torch.Generator(device=device).manual_seed(seed)
+    options = {
+        "prompt": [prompt] * num_images,
+        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+        "width": width,
+        "height": height,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "output_type": "pil",
    }
+    if use_resolution_binning:
+        options["use_resolution_binning"] = True
+    images = []
+    for i in range(0, num_images, BATCH_SIZE):
+        batch_options = options.copy()
+        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        if device.type == "cuda":
+            with torch.autocast("cuda", dtype=torch.float16):
+                outputs = sd_pipe(**batch_options)
+        else:
+            outputs = sd_pipe(**batch_options)
+        images.extend(outputs.images)
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
 
+@spaces.GPU
+def generate(
+    input_dict: dict,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+    convert_to_speech: bool = False,
+    tts_rate: float = 1.0,
+    tts_voice: str = "en-US-JennyNeural",
+):
+    """
+    Generates chatbot responses with support for multimodal input and TTS conversion.
+    When files (images or videos) are provided, Qwen2VL is used.
+    Otherwise, the FastThink-0.5B text model is used.
+    After generating the response, if convert_to_speech is True the text is passed to the TTS function.
+    """
+    text = input_dict["text"].strip()
+    files = input_dict.get("files", [])
 
+    # Determine which branch to use: multimodal (if files provided) or text-only.
+    if files:
+        # Process uploaded files as images (or videos)
+        if len(files) > 1:
+            images = [load_image(image) for image in files]
+        else:
+            images = [load_image(files[0])]
+        messages = [{
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing multimodal input...")
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        final_response = buffer
+    else:
+        conversation = clean_chat_history(chat_history)
+        conversation.append({"role": "user", "content": text})
+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model.device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "top_p": top_p,
+            "top_k": top_k,
+            "temperature": temperature,
+            "num_beams": 1,
+            "repetition_penalty": repetition_penalty,
+        }
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()
+        outputs = []
+        yield progress_bar_html("Processing text...")
+        for new_text in streamer:
+            outputs.append(new_text)
+            yield "".join(outputs)
+        final_response = "".join(outputs)
+
+    # Yield the final text response.
+    yield final_response
+
+    # If TTS conversion is enabled, log the message and generate speech.
+    if convert_to_speech:
+        print("Generate Response to Generate Speech")
+        # Here tts_rate can be used to adjust parameters if needed.
+        output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+        yield gr.Audio(output_file, autoplay=True)
+
+with gr.Blocks() as demo:
+    with gr.Sidebar():
+        gr.Markdown("# TTS Conversion")
+        tts_rate_slider = gr.Slider(label="TTS Rate", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
+        tts_voice_radio = gr.Radio(choices=TTS_VOICES, label="Choose TTS Voice", value="en-US-JennyNeural")
+        convert_to_speech_checkbox = gr.Checkbox(label="Convert to Speech", value=False)
+
+    chat_interface = gr.ChatInterface(
+        fn=generate,
+        additional_inputs=[
+            gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+            gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+            gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+            gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+            gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+            # Pass TTS parameters to the generate function.
+            convert_to_speech_checkbox,
+            tts_rate_slider,
+            tts_voice_radio,
+        ],
+        examples=[
+            ["Write the Python Program for Array Rotation"],
+            [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
+            [{"text": "Describe the Ad", "files": ["examples/coca.mp4"]}],
+            [{"text": "Summarize the event in video", "files": ["examples/sky.mp4"]}],
+            [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
+            ["Who is Nikola Tesla, and why did he die?"],
+            [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+            ["What causes rainbows to form?"],
+        ],
+        cache_examples=False,
+        type="messages",
+        description="# **QwQ Edge: Multimodal (image upload uses Qwen2-VL) with TTS conversion**",
+        fill_height=True,
+        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Enter text or upload files"),
+        stop_btn="Stop Generation",
+        multimodal=True,
    )
 
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch(share=True)
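
Note: the new generate function accepts a tts_rate value from the sidebar slider but, as its inline comment says, does not yet forward it to Edge TTS. The snippet below is a minimal standalone sketch, not part of this commit, showing how the same Communicate/save call pattern could honour the slider, assuming edge-tts is installed and that the installed version accepts a percentage-style rate keyword; the slider_to_rate and speak_demo helpers are illustrative names.

import asyncio
import edge_tts

def slider_to_rate(tts_rate: float) -> str:
    # Hypothetical mapping of the 0.5-2.0 UI slider onto edge-tts's percentage rate string.
    return f"{int(round((tts_rate - 1.0) * 100)):+d}%"  # 1.0 -> "+0%", 1.5 -> "+50%"

async def speak_demo(text: str, voice: str = "en-US-JennyNeural", tts_rate: float = 1.0) -> str:
    # Same Communicate/save pattern as text_to_speech() in app.py, with the rate option added.
    communicate = edge_tts.Communicate(text, voice, rate=slider_to_rate(tts_rate))
    await communicate.save("output.mp3")
    return "output.mp3"

if __name__ == "__main__":
    # Mirrors the asyncio.run() call the Space makes once the final text response is ready.
    print(asyncio.run(speak_demo("Hello from QwQ Edge.", tts_rate=1.2)))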