srinuksv committed on
Commit 4aa1ed6 · verified · 1 Parent(s): 6c79ba4

Create app.py

Files changed (1): app.py +413 -0
app.py ADDED
@@ -0,0 +1,413 @@
import io
import os
import ffmpeg

import numpy as np
import gradio as gr
import soundfile as sf

import modelscope_studio.components.base as ms
import modelscope_studio.components.antd as antd
import gradio.processing_utils as processing_utils

from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
from gradio_client import utils as client_utils
from qwen_omni_utils import process_mm_info
from argparse import ArgumentParser

def _load_model_processor(args):
    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    # Check if flash-attn2 flag is enabled and load model accordingly
    if args.flash_attn2:
        model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path,
                                                 torch_dtype='auto',
                                                 attn_implementation='flash_attention_2',
                                                 device_map=device_map)
    else:
        model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path, device_map=device_map)

    processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
    return model, processor

def _launch_demo(args, model, processor):
    # Voice settings
    VOICE_LIST = ['Chelsie', 'Ethan']
    DEFAULT_VOICE = 'Chelsie'

    default_system_prompt = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'

    language = args.ui_language

    def get_text(text: str, cn_text: str):
        if language == 'en':
            return text
        if language == 'zh':
            return cn_text
        return text

    def convert_webm_to_mp4(input_file, output_file):
        try:
            (
                ffmpeg
                .input(input_file)
                .output(output_file, acodec='aac', ar='16000', audio_bitrate='192k')
                .run(quiet=True, overwrite_output=True)
            )
            print(f"Conversion successful: {output_file}")
        except ffmpeg.Error as e:
            print("An error occurred during conversion.")
            print(e.stderr.decode('utf-8'))

    def format_history(history: list, system_prompt: str):
        # Convert the Gradio chat history into the message list expected by the processor,
        # mapping uploaded files to image/video/audio content entries by MIME type.
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for item in history:
            if isinstance(item["content"], str):
                messages.append({"role": item['role'], "content": item['content']})
            elif item["role"] == "user" and isinstance(item["content"], (list, tuple)):
                file_path = item["content"][0]

                mime_type = client_utils.get_mimetype(file_path)
                if mime_type.startswith("image"):
                    messages.append({
                        "role": item['role'],
                        "content": [{
                            "type": "image",
                            "image": file_path
                        }]
                    })
                elif mime_type.startswith("video"):
                    messages.append({
                        "role": item['role'],
                        "content": [{
                            "type": "video",
                            "video": file_path
                        }]
                    })
                elif mime_type.startswith("audio"):
                    messages.append({
                        "role": item['role'],
                        "content": [{
                            "type": "audio",
                            "audio": file_path,
                        }]
                    })
        return messages

    def predict(messages, voice=DEFAULT_VOICE):
        # Run one generation pass: yield the text response first, then the synthesized speech.
        print('predict history: ', messages)

        text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

        inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device).to(model.dtype)

        text_ids, audio = model.generate(**inputs, spk=voice, use_audio_in_video=True)

        response = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        response = response[0].split("\n")[-1]
        yield {"type": "text", "data": response}

        # Convert the float waveform to 16-bit PCM and cache it as a WAV file for the chatbot.
        audio = np.array(audio * 32767).astype(np.int16)
        wav_io = io.BytesIO()
        sf.write(wav_io, audio, samplerate=24000, format="WAV")
        wav_io.seek(0)
        wav_bytes = wav_io.getvalue()
        audio_path = processing_utils.save_bytes_to_cache(
            wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
        yield {"type": "audio", "data": audio_path}

    def media_predict(audio, video, history, system_prompt, voice_choice):
        # First yield
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=False),  # submit_btn
            gr.update(visible=True),  # stop_btn
        )

        if video is not None:
            convert_webm_to_mp4(video, video.replace('.webm', '.mp4'))
            video = video.replace(".webm", ".mp4")
        files = [audio, video]

        for f in files:
            if f:
                history.append({"role": "user", "content": (f, )})

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)

        history.append({"role": "assistant", "content": ""})

        for chunk in predict(formatted_history, voice_choice):
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
                yield (
                    None,  # microphone
                    None,  # webcam
                    history,  # media_chatbot
                    gr.update(visible=False),  # submit_btn
                    gr.update(visible=True),  # stop_btn
                )
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })

        # Final yield
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )

    def chat_predict(text, audio, image, video, history, system_prompt, voice_choice):
        # Process text input
        if text:
            history.append({"role": "user", "content": text})

        # Process audio input
        if audio:
            history.append({"role": "user", "content": (audio, )})

        # Process image input
        if image:
            history.append({"role": "user", "content": (image, )})

        # Process video input
        if video:
            history.append({"role": "user", "content": (video, )})

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)

        yield None, None, None, None, history

        history.append({"role": "assistant", "content": ""})
        for chunk in predict(formatted_history, voice_choice):
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })
                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history

    with gr.Blocks() as demo, ms.Application(), antd.ConfigProvider():
        with gr.Sidebar(open=False):
            system_prompt_textbox = gr.Textbox(label="System Prompt",
                                               value=default_system_prompt)
        with antd.Flex(gap="small", justify="center", align="center"):
            with antd.Flex(vertical=True, gap="small", align="center"):
                antd.Typography.Title("Qwen2.5-Omni Demo",
                                      level=1,
                                      elem_style=dict(margin=0, fontSize=28))
                with antd.Flex(vertical=True, gap="small"):
                    antd.Typography.Text(get_text("🎯 Instructions for use:",
                                                  "🎯 使用说明:"),
                                         strong=True)
                    antd.Typography.Text(
                        get_text(
                            "1️⃣ Click the Audio Record button or the Camera Record button.",
                            "1️⃣ 点击音频录制按钮,或摄像头-录制按钮"))
                    antd.Typography.Text(
                        get_text("2️⃣ Input audio or video.", "2️⃣ 输入音频或者视频"))
                    antd.Typography.Text(
                        get_text(
                            "3️⃣ Click the submit button and wait for the model's response.",
                            "3️⃣ 点击提交并等待模型的回答"))
        voice_choice = gr.Dropdown(label="Voice Choice",
                                   choices=VOICE_LIST,
                                   value=DEFAULT_VOICE)
        with gr.Tabs():
            with gr.Tab("Online"):
                with gr.Row():
                    with gr.Column(scale=1):
                        microphone = gr.Audio(sources=['microphone'],
                                              type="filepath")
                        webcam = gr.Video(sources=['webcam'],
                                          height=400,
                                          include_audio=True)
                        submit_btn = gr.Button(get_text("Submit", "提交"),
                                               variant="primary")
                        stop_btn = gr.Button(get_text("Stop", "停止"), visible=False)
                        clear_btn = gr.Button(get_text("Clear History", "清除历史"))
                    with gr.Column(scale=2):
                        media_chatbot = gr.Chatbot(height=650, type="messages")

                def clear_history():
                    return [], gr.update(value=None), gr.update(value=None)

                submit_event = submit_btn.click(fn=media_predict,
                                                inputs=[
                                                    microphone, webcam,
                                                    media_chatbot,
                                                    system_prompt_textbox,
                                                    voice_choice
                                                ],
                                                outputs=[
                                                    microphone, webcam,
                                                    media_chatbot, submit_btn,
                                                    stop_btn
                                                ])
                stop_btn.click(
                    fn=lambda:
                    (gr.update(visible=True), gr.update(visible=False)),
                    inputs=None,
                    outputs=[submit_btn, stop_btn],
                    cancels=[submit_event],
                    queue=False)
                clear_btn.click(fn=clear_history,
                                inputs=None,
                                outputs=[media_chatbot, microphone, webcam])

            with gr.Tab("Offline"):
                chatbot = gr.Chatbot(type="messages", height=650)

                # Media upload section in one row
                with gr.Row(equal_height=True):
                    audio_input = gr.Audio(sources=["upload"],
                                           type="filepath",
                                           label="Upload Audio",
                                           elem_classes="media-upload",
                                           scale=1)
                    image_input = gr.Image(sources=["upload"],
                                           type="filepath",
                                           label="Upload Image",
                                           elem_classes="media-upload",
                                           scale=1)
                    video_input = gr.Video(sources=["upload"],
                                           label="Upload Video",
                                           elem_classes="media-upload",
                                           scale=1)

                # Text input section
                text_input = gr.Textbox(show_label=False,
                                        placeholder="Enter text here...")

                # Control buttons
                with gr.Row():
                    submit_btn = gr.Button(get_text("Submit", "提交"),
                                           variant="primary",
                                           size="lg")
                    stop_btn = gr.Button(get_text("Stop", "停止"),
                                         visible=False,
                                         size="lg")
                    clear_btn = gr.Button(get_text("Clear History", "清除历史"),
                                          size="lg")

                def clear_chat_history():
                    return [], gr.update(value=None), gr.update(
                        value=None), gr.update(value=None), gr.update(value=None)

                submit_event = gr.on(
                    triggers=[submit_btn.click, text_input.submit],
                    fn=chat_predict,
                    inputs=[
                        text_input, audio_input, image_input, video_input, chatbot,
                        system_prompt_textbox, voice_choice
                    ],
                    outputs=[
                        text_input, audio_input, image_input, video_input, chatbot
                    ])

                stop_btn.click(fn=lambda:
                               (gr.update(visible=True), gr.update(visible=False)),
                               inputs=None,
                               outputs=[submit_btn, stop_btn],
                               cancels=[submit_event],
                               queue=False)

                clear_btn.click(fn=clear_chat_history,
                                inputs=None,
                                outputs=[
                                    chatbot, text_input, audio_input, image_input,
                                    video_input
                                ])

        # Add some custom CSS to improve the layout
        gr.HTML("""
        <style>
            .media-upload {
                margin: 10px;
                min-height: 160px;
            }
            .media-upload > .wrap {
                border: 2px dashed #ccc;
                border-radius: 8px;
                padding: 10px;
                height: 100%;
            }
            .media-upload:hover > .wrap {
                border-color: #666;
            }
            /* Make upload areas equal width */
            .media-upload {
                flex: 1;
                min-width: 0;
            }
        </style>
        """)

    demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
                                                                   ssr_mode=False,
                                                                   share=args.share,
                                                                   inbrowser=args.inbrowser,
                                                                   server_port=args.server_port,
                                                                   server_name=args.server_name)

DEFAULT_CKPT_PATH = "Qwen/Qwen2.5-Omni-7B"


def _get_args():
    parser = ArgumentParser()

    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
    parser.add_argument('--ui-language', type=str, choices=['en', 'zh'], default='en', help='Display language for the UI.')

    args = parser.parse_args()
    return args
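
# A minimal sketch of how this script is typically launched (an assumption, not part of
# this commit): with the dependencies installed -- e.g. gradio, transformers, soundfile,
# ffmpeg-python, modelscope_studio and qwen-omni-utils, plus flash-attn if --flash-attn2
# is used -- it can be started with something like:
#     python app.py --flash-attn2 --ui-language en
# Note that the block below forces args.share to True, so a public Gradio link is created
# regardless of the --share flag.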

if __name__ == "__main__":
    args = _get_args()
    args.share = True
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)