prithivMLmods committed on
Commit
8ec6920
·
verified ·
1 Parent(s): 0d5b113

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -35
app.py CHANGED
@@ -83,12 +83,12 @@ orpheus_tts_model.to(tts_device)
83
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
84
  print(f"Orpheus TTS model loaded to {tts_device}")
85
 
86
- # Global parameters for chat responses
87
  MAX_MAX_NEW_TOKENS = 2048
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
- # (Image generation code has been removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
@@ -200,7 +200,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
200
  if not text.strip():
201
  return None
202
  try:
203
- # Generate speech without internal progress calls (UI progress is handled externally)
204
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
205
  with torch.no_grad():
206
  generated_ids = orpheus_tts_model.generate(
@@ -233,7 +233,7 @@ def generate(
233
  repetition_penalty: float = 1.2,
234
  ):
235
  """
236
- Generates chatbot responses with support for video processing,
237
  TTS, and LLM-augmented TTS.
238
 
239
  Trigger commands:
@@ -335,39 +335,64 @@ def generate(
335
  yield gr.Audio(audio_output, autoplay=True)
336
  return
337
 
338
- # Default branch for regular chat (text without explicit TTS trigger)
339
  conversation = clean_chat_history(chat_history)
340
  conversation.append({"role": "user", "content": text})
341
- # Process using the DeepHermes LLM
342
- input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
343
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
344
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
345
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
346
- input_ids = input_ids.to(hermes_llm_model.device)
347
- streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
348
- generation_kwargs = {
349
- "input_ids": input_ids,
350
- "streamer": streamer,
351
- "max_new_tokens": max_new_tokens,
352
- "do_sample": True,
353
- "top_p": top_p,
354
- "top_k": top_k,
355
- "temperature": temperature,
356
- "num_beams": 1,
357
- "repetition_penalty": repetition_penalty,
358
- }
359
- t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
360
- t.start()
361
- outputs = []
362
- yield progress_bar_html("Processing with DeepHermes LLM")
363
- for new_text in streamer:
364
- outputs.append(new_text)
365
- yield "".join(outputs)
366
- final_response = "".join(outputs)
367
- yield final_response
368
- # Also convert the final response to speech using a default voice ("tara")
369
- audio_output = generate_speech(final_response, "tara", temperature, top_p, repetition_penalty, max_new_tokens)
370
- yield gr.Audio(audio_output, autoplay=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  # Gradio Interface
373
  demo = gr.ChatInterface(
@@ -386,6 +411,7 @@ demo = gr.ChatInterface(
386
  ["@josh-llm What causes rainbows to form?"],
387
  ["@dan-tts Yo, I’m Dan, [groan] and yes, I can even sound annoyed if I have to."],
388
  ["Write python program for array rotation"],
 
389
  ["@tara-tts Hey there, my name is Tara, [laugh] and I’m a speech generation model that can sound just like you!"],
390
  ["@tara-llm Who is Nikola Tesla, and why did he die?"],
391
  ["@emma-llm Explain the causes of rainbows"],
 
83
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
84
  print(f"Orpheus TTS model loaded to {tts_device}")
85
 
86
+ # Some global parameters for chat responses
87
  MAX_MAX_NEW_TOKENS = 2048
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
+ # (Image generation related code has been fully removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
 
200
  if not text.strip():
201
  return None
202
  try:
203
+ # Removed in-function progress calls to maintain UI consistency.
204
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
205
  with torch.no_grad():
206
  generated_ids = orpheus_tts_model.generate(
 
233
  repetition_penalty: float = 1.2,
234
  ):
235
  """
236
+ Generates chatbot responses with support for multimodal input, video processing,
237
  TTS, and LLM-augmented TTS.
238
 
239
  Trigger commands:
 
335
  yield gr.Audio(audio_output, autoplay=True)
336
  return
337
 
338
+ # Default branch for regular chat (text and multimodal without TTS).
339
  conversation = clean_chat_history(chat_history)
340
  conversation.append({"role": "user", "content": text})
341
+ # If files are provided, only non-image files (e.g. video) are processed via Qwen2VL.
342
+ if files:
343
+ # Process files using the processor (this branch no longer handles image generation)
344
+ if len(files) > 1:
345
+ inputs_list = [load_image(image) for image in files]
346
+ elif len(files) == 1:
347
+ inputs_list = [load_image(files[0])]
348
+ else:
349
+ inputs_list = []
350
+ messages = [{
351
+ "role": "user",
352
+ "content": [
353
+ *[{"type": "image", "image": img} for img in inputs_list],
354
+ {"type": "text", "text": text},
355
+ ]
356
+ }]
357
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
358
+ inputs = processor(text=[prompt_full], images=inputs_list, return_tensors="pt", padding=True).to("cuda")
359
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
360
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
361
+ thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
362
+ thread.start()
363
+ buffer = ""
364
+ yield progress_bar_html("Processing with Qwen2VL")
365
+ for new_text in streamer:
366
+ buffer += new_text.replace("<|im_end|>", "")
367
+ time.sleep(0.01)
368
+ yield buffer
369
+ else:
370
+ input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
371
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
372
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
373
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
374
+ input_ids = input_ids.to(hermes_llm_model.device)
375
+ streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
376
+ generation_kwargs = {
377
+ "input_ids": input_ids,
378
+ "streamer": streamer,
379
+ "max_new_tokens": max_new_tokens,
380
+ "do_sample": True,
381
+ "top_p": top_p,
382
+ "top_k": top_k,
383
+ "temperature": temperature,
384
+ "num_beams": 1,
385
+ "repetition_penalty": repetition_penalty,
386
+ }
387
+ t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
388
+ t.start()
389
+ outputs = []
390
+ yield progress_bar_html("Processing with DeepHermes LLM")
391
+ for new_text in streamer:
392
+ outputs.append(new_text)
393
+ yield "".join(outputs)
394
+ final_response = "".join(outputs)
395
+ yield final_response
396
 
397
  # Gradio Interface
398
  demo = gr.ChatInterface(
 
411
  ["@josh-llm What causes rainbows to form?"],
412
  ["@dan-tts Yo, I’m Dan, [groan] and yes, I can even sound annoyed if I have to."],
413
  ["Write python program for array rotation"],
414
+ [{"text": "summarize the letter", "files": ["examples/1.png"]}],
415
  ["@tara-tts Hey there, my name is Tara, [laugh] and I’m a speech generation model that can sound just like you!"],
416
  ["@tara-llm Who is Nikola Tesla, and why did he die?"],
417
  ["@emma-llm Explain the causes of rainbows"],