prithivMLmods commited on
Commit
0d5b113
·
verified ·
1 Parent(s): a5f372d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -61
app.py CHANGED
@@ -83,12 +83,12 @@ orpheus_tts_model.to(tts_device)
83
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
84
  print(f"Orpheus TTS model loaded to {tts_device}")
85
 
86
- # Some global parameters for chat responses
87
  MAX_MAX_NEW_TOKENS = 2048
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
- # (Image generation related code has been fully removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
@@ -200,7 +200,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
200
  if not text.strip():
201
  return None
202
  try:
203
- # Removed in-function progress calls to maintain UI consistency.
204
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
205
  with torch.no_grad():
206
  generated_ids = orpheus_tts_model.generate(
@@ -233,7 +233,7 @@ def generate(
233
  repetition_penalty: float = 1.2,
234
  ):
235
  """
236
- Generates chatbot responses with support for multimodal input, video processing,
237
  TTS, and LLM-augmented TTS.
238
 
239
  Trigger commands:
@@ -335,64 +335,39 @@ def generate(
335
  yield gr.Audio(audio_output, autoplay=True)
336
  return
337
 
338
- # Default branch for regular chat (text and multimodal without TTS).
339
  conversation = clean_chat_history(chat_history)
340
  conversation.append({"role": "user", "content": text})
341
- # If files are provided, only non-image files (e.g. video) are processed via Qwen2VL.
342
- if files:
343
- # Process files using the processor (this branch no longer handles image generation)
344
- if len(files) > 1:
345
- inputs_list = [load_image(image) for image in files]
346
- elif len(files) == 1:
347
- inputs_list = [load_image(files[0])]
348
- else:
349
- inputs_list = []
350
- messages = [{
351
- "role": "user",
352
- "content": [
353
- *[{"type": "image", "image": img} for img in inputs_list],
354
- {"type": "text", "text": text},
355
- ]
356
- }]
357
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
358
- inputs = processor(text=[prompt_full], images=inputs_list, return_tensors="pt", padding=True).to("cuda")
359
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
360
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
361
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
362
- thread.start()
363
- buffer = ""
364
- yield progress_bar_html("Processing with Qwen2VL")
365
- for new_text in streamer:
366
- buffer += new_text.replace("<|im_end|>", "")
367
- time.sleep(0.01)
368
- yield buffer
369
- else:
370
- input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
371
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
372
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
373
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
374
- input_ids = input_ids.to(hermes_llm_model.device)
375
- streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
376
- generation_kwargs = {
377
- "input_ids": input_ids,
378
- "streamer": streamer,
379
- "max_new_tokens": max_new_tokens,
380
- "do_sample": True,
381
- "top_p": top_p,
382
- "top_k": top_k,
383
- "temperature": temperature,
384
- "num_beams": 1,
385
- "repetition_penalty": repetition_penalty,
386
- }
387
- t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
388
- t.start()
389
- outputs = []
390
- yield progress_bar_html("Processing with DeepHermes LLM")
391
- for new_text in streamer:
392
- outputs.append(new_text)
393
- yield "".join(outputs)
394
- final_response = "".join(outputs)
395
- yield final_response
396
 
397
  # Gradio Interface
398
  demo = gr.ChatInterface(
@@ -411,7 +386,6 @@ demo = gr.ChatInterface(
411
  ["@josh-llm What causes rainbows to form?"],
412
  ["@dan-tts Yo, I’m Dan, [groan] and yes, I can even sound annoyed if I have to."],
413
  ["Write python program for array rotation"],
414
- [{"text": "summarize the letter", "files": ["examples/1.png"]}],
415
  ["@tara-tts Hey there, my name is Tara, [laugh] and I’m a speech generation model that can sound just like you!"],
416
  ["@tara-llm Who is Nikola Tesla, and why did he die?"],
417
  ["@emma-llm Explain the causes of rainbows"],
 
83
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
84
  print(f"Orpheus TTS model loaded to {tts_device}")
85
 
86
+ # Global parameters for chat responses
87
  MAX_MAX_NEW_TOKENS = 2048
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
+ # (Image generation code has been removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
 
200
  if not text.strip():
201
  return None
202
  try:
203
+ # Generate speech without internal progress calls (UI progress is handled externally)
204
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
205
  with torch.no_grad():
206
  generated_ids = orpheus_tts_model.generate(
 
233
  repetition_penalty: float = 1.2,
234
  ):
235
  """
236
+ Generates chatbot responses with support for video processing,
237
  TTS, and LLM-augmented TTS.
238
 
239
  Trigger commands:
 
335
  yield gr.Audio(audio_output, autoplay=True)
336
  return
337
 
338
+ # Default branch for regular chat (text without explicit TTS trigger)
339
  conversation = clean_chat_history(chat_history)
340
  conversation.append({"role": "user", "content": text})
341
+ # Process using the DeepHermes LLM
342
+ input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
343
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
344
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
345
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
346
+ input_ids = input_ids.to(hermes_llm_model.device)
347
+ streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
348
+ generation_kwargs = {
349
+ "input_ids": input_ids,
350
+ "streamer": streamer,
351
+ "max_new_tokens": max_new_tokens,
352
+ "do_sample": True,
353
+ "top_p": top_p,
354
+ "top_k": top_k,
355
+ "temperature": temperature,
356
+ "num_beams": 1,
357
+ "repetition_penalty": repetition_penalty,
358
+ }
359
+ t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
360
+ t.start()
361
+ outputs = []
362
+ yield progress_bar_html("Processing with DeepHermes LLM")
363
+ for new_text in streamer:
364
+ outputs.append(new_text)
365
+ yield "".join(outputs)
366
+ final_response = "".join(outputs)
367
+ yield final_response
368
+ # Also convert the final response to speech using a default voice ("tara")
369
+ audio_output = generate_speech(final_response, "tara", temperature, top_p, repetition_penalty, max_new_tokens)
370
+ yield gr.Audio(audio_output, autoplay=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  # Gradio Interface
373
  demo = gr.ChatInterface(
 
386
  ["@josh-llm What causes rainbows to form?"],
387
  ["@dan-tts Yo, I’m Dan, [groan] and yes, I can even sound annoyed if I have to."],
388
  ["Write python program for array rotation"],
 
389
  ["@tara-tts Hey there, my name is Tara, [laugh] and I’m a speech generation model that can sound just like you!"],
390
  ["@tara-llm Who is Nikola Tesla, and why did he die?"],
391
  ["@emma-llm Explain the causes of rainbows"],