Rausda6 commited on
Commit
2114e35
·
verified ·
1 Parent(s): 7118f9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +392 -316
app.py CHANGED
@@ -9,405 +9,481 @@ import os
9
  import time
10
  import mimetypes
11
  import torch
 
12
  from typing import List, Dict
13
- from transformers import AutoTokenizer, AutoModelForCausalLM
14
 
15
  # Constants
16
  MAX_FILE_SIZE_MB = 20
17
- MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
18
 
19
  MODEL_ID = "unsloth/gemma-3-1b-pt"
20
 
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
- model = AutoModelForCausalLM.from_pretrained(
23
- MODEL_ID,
24
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
25
- device_map="auto"
26
- ).eval()
27
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  class PodcastGenerator:
30
  def __init__(self):
31
- pass
32
-
33
- async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
34
- example = """
35
- {
36
- "topic": "AGI",
37
- "podcast": [
38
- {
39
- "speaker": 2,
40
- "line": "So, AGI, huh? Seems like everyone's talking about it these days."
41
- },
42
- {
43
- "speaker": 1,
44
- "line": "Yeah, it's definitely having a moment, isn't it?"
45
- },
46
- {
47
- "speaker": 2,
48
- "line": "It is and for good reason, right? I mean, you've been digging into this stuff, listening to the podcasts and everything. What really stood out to you? What got you hooked?"
49
- },
50
- {
51
- "speaker": 1,
52
- "line": "I like that. It really is."
53
- },
54
- {
55
- "speaker": 2,
56
- "line": "And honestly, that's a responsibility that extends beyond just the researchers and the policymakers."
57
- },
58
- {
59
- "speaker": 1,
60
- "line": "100%"
61
- },
62
- {
63
- "speaker": 2,
64
- "line": "So to everyone listening out there I'll leave you with this. As AGI continues to develop, what role do you want to play in shaping its future?"
65
- },
66
- {
67
- "speaker": 1,
68
- "line": "That's a question worth pondering."
69
- },
70
- {
71
- "speaker": 2,
72
- "line": "It certainly is and on that note, we'll wrap up this deep dive. Thanks for listening, everyone."
73
- },
74
- {
75
- "speaker": 1,
76
- "line": "Peace."
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  }
78
- ]
79
- }
80
- """
81
 
82
  if language == "Auto Detect":
83
- language_instruction = "- The podcast MUST be in the same language as the user input."
84
  else:
85
- language_instruction = f"- The podcast MUST be in {language} language"
86
-
87
- system_prompt = f"""
88
- You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
89
- {language_instruction}
90
- - The podcast should have 2 speakers.
91
- - The podcast should be long.
92
- - Do not use names for the speakers.
93
- - The podcast should be interesting, lively, and engaging, and hook the listener from the start.
94
- - The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
95
- - The script must be in JSON format.
96
- Follow this example structure:
97
- {example}
98
- """
99
- # Construct system and user prompt
100
- if prompt and file_obj:
101
- user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
102
- elif prompt:
103
- user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
104
- else:
105
- user_prompt = "Please generate a podcast script based on the uploaded file."
106
 
107
- # NOTE: file_obj cannot be passed to a text-only LLM
108
- if file_obj:
109
- print("Warning: Uploaded file is ignored in this version because external LLM does not support file input.")
 
110
 
111
- # Build prompt
112
- full_prompt = f"""{system_prompt}
113
 
114
- {user_prompt}
115
 
116
- Return the result strictly as a JSON object in the format:
117
- {{
118
- "topic": "{prompt}",
119
- "podcast": [
120
- {{ "speaker": 1, "line": "..." }},
121
- {{ "speaker": 2, "line": "..." }}
122
- ]
123
- }}
124
- """
125
 
126
  try:
127
  if progress:
128
  progress(0.3, "Generating podcast script...")
129
 
130
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
131
- output = model.generate(**inputs, max_new_tokens=1024)
132
- text = tokenizer.decode(output[0], skip_special_tokens=True)
133
-
134
- except Exception as e:
135
- raise Exception(f"Failed to generate podcast script: {e}")
136
-
137
- print(f"Generated podcast script:\n{text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- if progress:
140
- progress(0.4, "Script generated successfully!")
141
 
142
- try:
143
- return json.loads(text)
144
- except json.JSONDecodeError:
145
- raise Exception("The model did not return valid JSON. Please refine the prompt.")
146
 
147
-
148
- async def _read_file_bytes(self, file_obj) -> bytes:
149
- """Read file bytes from a file object"""
150
- # Check file size before reading
151
- if hasattr(file_obj, 'size'):
152
- file_size = file_obj.size
153
- else:
154
- file_size = os.path.getsize(file_obj.name)
155
 
156
- if file_size > MAX_FILE_SIZE_BYTES:
157
- raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
158
 
159
- if hasattr(file_obj, 'read'):
160
- return file_obj.read()
161
- else:
162
- async with aiofiles.open(file_obj.name, 'rb') as f:
163
- return await f.read()
164
-
165
- def _get_mime_type(self, filename: str) -> str:
166
- """Determine MIME type based on file extension"""
167
- ext = os.path.splitext(filename)[1].lower()
168
- if ext == '.pdf':
169
- return "application/pdf"
170
- elif ext == '.txt':
171
- return "text/plain"
172
- else:
173
- # Fallback to the default mime type detector
174
- mime_type, _ = mimetypes.guess_type(filename)
175
- return mime_type or "application/octet-stream"
 
176
 
177
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
 
178
  voice = speaker1 if speaker == 1 else speaker2
179
  speech = edge_tts.Communicate(text, voice)
180
 
181
- temp_filename = f"temp_{uuid.uuid4()}.wav"
182
- try:
183
- # Add timeout to TTS generation
184
- await asyncio.wait_for(speech.save(temp_filename), timeout=30) # 30 seconds timeout
185
- return temp_filename
186
- except asyncio.TimeoutError:
187
- if os.path.exists(temp_filename):
188
- os.remove(temp_filename)
189
- raise Exception("Text-to-speech generation timed out. Please try with a shorter text.")
190
- except Exception as e:
191
- if os.path.exists(temp_filename):
192
- os.remove(temp_filename)
193
- raise e
 
 
 
 
 
 
 
 
 
194
 
195
  async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
 
196
  if progress:
197
  progress(0.9, "Combining audio files...")
198
 
199
- combined_audio = AudioSegment.empty()
200
- for audio_file in audio_files:
201
- combined_audio += AudioSegment.from_file(audio_file)
202
- os.remove(audio_file) # Clean up temporary files
203
-
204
- output_filename = f"output_{uuid.uuid4()}.wav"
205
- combined_audio.export(output_filename, format="wav")
206
-
207
- if progress:
208
- progress(1.0, "Podcast generated successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- return output_filename
 
 
 
 
 
211
 
212
  async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
 
213
  try:
214
  if progress:
215
  progress(0.1, "Starting podcast generation...")
216
-
217
- # Set overall timeout for the entire process
218
- return await asyncio.wait_for(
219
- self._generate_podcast_internal(input_text, language, speaker1, speaker2, file_obj, progress),
220
- timeout=600 # 10 minutes total timeout
221
- )
222
- except asyncio.TimeoutError:
223
- raise Exception("The podcast generation process timed out. Please try with shorter text or try again later.")
224
- except Exception as e:
225
- raise Exception(f"Error generating podcast: {str(e)}")
226
-
227
- async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
228
- if progress:
229
- progress(0.2, "Generating podcast script...")
230
-
231
- podcast_json = await self.generate_script(input_text, language, file_obj, progress)
232
-
233
- if progress:
234
- progress(0.5, "Converting text to speech...")
235
-
236
- # Process TTS in batches for concurrent processing
237
- audio_files = []
238
- total_lines = len(podcast_json['podcast'])
239
-
240
- # Define batch size to control concurrency
241
- batch_size = 10 # Adjust based on system resources
242
-
243
- # Process in batches
244
- for batch_start in range(0, total_lines, batch_size):
245
- batch_end = min(batch_start + batch_size, total_lines)
246
- batch = podcast_json['podcast'][batch_start:batch_end]
247
 
248
- # Create tasks for concurrent processing
249
- tts_tasks = []
250
- for item in batch:
251
- tts_task = self.tts_generate(item['line'], item['speaker'], speaker1, speaker2)
252
- tts_tasks.append(tts_task)
 
 
 
 
253
 
254
- try:
255
- # Process batch concurrently
256
- batch_results = await asyncio.gather(*tts_tasks, return_exceptions=True)
257
-
258
- # Check for exceptions and handle results
259
- for i, result in enumerate(batch_results):
260
- if isinstance(result, Exception):
261
- # Clean up any files already created
262
- for file in audio_files:
263
- if os.path.exists(file):
264
- os.remove(file)
265
- raise Exception(f"Error generating speech: {str(result)}")
266
- else:
267
- audio_files.append(result)
268
 
269
- # Update progress
270
- if progress:
271
- current_progress = 0.5 + (0.4 * (batch_end / total_lines))
272
- progress(current_progress, f"Processed {batch_end}/{total_lines} speech segments...")
273
-
274
- except Exception as e:
275
- # Clean up any files already created
276
- for file in audio_files:
277
- if os.path.exists(file):
278
- os.remove(file)
279
- raise Exception(f"Error in batch TTS generation: {str(e)}")
280
-
281
- combined_audio = await self.combine_audio_files(audio_files, progress)
282
- return combined_audio
283
 
284
- async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, progress=None) -> str:
285
- start_time = time.time()
 
 
 
 
286
 
287
- voice_names = {
288
- "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
289
- "Ava - English (United States)": "en-US-AvaMultilingualNeural",
290
- "Brian - English (United States)": "en-US-BrianMultilingualNeural",
291
- "Emma - English (United States)": "en-US-EmmaMultilingualNeural",
292
- "Florian - German (Germany)": "de-DE-FlorianMultilingualNeural",
293
- "Seraphina - German (Germany)": "de-DE-SeraphinaMultilingualNeural",
294
- "Remy - French (France)": "fr-FR-RemyMultilingualNeural",
295
- "Vivienne - French (France)": "fr-FR-VivienneMultilingualNeural"
296
- }
 
 
 
 
297
 
298
- speaker1 = voice_names[speaker1]
299
- speaker2 = voice_names[speaker2]
 
300
 
301
  try:
302
  if progress:
303
  progress(0.05, "Processing input...")
304
 
305
- api_key = "" # No API key needed for local model
 
 
 
 
 
 
 
 
306
 
307
  podcast_generator = PodcastGenerator()
308
- podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, input_file, progress)
 
 
309
 
310
  end_time = time.time()
311
- print(f"Total podcast generation time: {end_time - start_time:.2f} seconds")
312
- return podcast
313
-
314
  except Exception as e:
315
- # Ensure we show a user-friendly error
316
  error_msg = str(e)
317
- if "rate limit" in error_msg.lower():
318
- raise Exception("Rate limit exceeded. Please try again later or use your own API key.")
319
- elif "timeout" in error_msg.lower():
320
- raise Exception("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.")
321
- else:
322
- raise Exception(f"Error: {error_msg}")
323
 
324
- # Gradio UI
325
  def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2):
326
- progress = gr.Progress()
 
 
 
 
 
 
 
327
 
328
- # Handle the file if uploaded
329
- file_obj = None
330
- if input_file is not None:
331
- file_obj = input_file
332
 
333
- # Use the progress function from Gradio
334
- def progress_callback(value, text):
335
- progress(value, text)
336
-
337
- # Run the async function in the event loop
338
- result = asyncio.run(process_input(
339
- input_text,
340
- file_obj,
341
- language,
342
- speaker1,
343
- speaker2,
344
- progress_callback
345
- ))
346
-
347
- return result
348
 
349
- def main():
350
- # Define language options
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  language_options = [
352
- "Auto Detect",
353
- "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian", "Azerbaijani",
354
- "Bahasa Indonesian", "Bangla", "Basque", "Bengali", "Bosnian", "Bulgarian",
355
- "Burmese", "Catalan", "Chinese Cantonese", "Chinese Mandarin",
356
- "Chinese Taiwanese", "Croatian", "Czech", "Danish", "Dutch", "English",
357
- "Estonian", "Filipino", "Finnish", "French", "Galician", "Georgian",
358
- "German", "Greek", "Hebrew", "Hindi", "Hungarian", "Icelandic", "Irish",
359
- "Italian", "Japanese", "Javanese", "Kannada", "Kazakh", "Khmer", "Korean",
360
- "Lao", "Latvian", "Lithuanian", "Macedonian", "Malay", "Malayalam",
361
- "Maltese", "Mongolian", "Nepali", "Norwegian Bokmål", "Pashto", "Persian",
362
- "Polish", "Portuguese", "Romanian", "Russian", "Serbian", "Slovak", "Slovene", "Somali", "Spanish", "Sundanese", "Swahili",
363
- "Swedish", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian", "Urdu",
364
- "Uzbek", "Vietnamese", "Welsh", "Zulu"
365
  ]
366
 
367
- # Define voice options
368
- voice_options = [
369
- "Andrew - English (United States)",
370
- "Ava - English (United States)",
371
- "Brian - English (United States)",
372
- "Emma - English (United States)",
373
- "Florian - German (Germany)",
374
- "Seraphina - German (Germany)",
375
- "Remy - French (France)",
376
- "Vivienne - French (France)"
377
- ]
378
 
379
- # Create Gradio interface
380
- with gr.Blocks(title="PodcastGen 2🎙️") as demo:
381
- gr.Markdown("# PodcastGen 2🎙️")
382
- gr.Markdown("Generate a 2-speaker podcast from text input or documents!")
 
 
 
 
383
 
384
  with gr.Row():
385
  with gr.Column(scale=2):
386
- input_text = gr.Textbox(label="Input Text", lines=10, placeholder="Enter text for podcast generation...")
 
 
 
 
 
387
 
388
  with gr.Column(scale=1):
389
- input_file = gr.File(label="Or Upload a PDF or TXT file", file_types=[".pdf", ".txt"])
 
 
 
 
390
 
391
  with gr.Row():
392
- with gr.Column():
393
- language = gr.Dropdown(label="Language", choices=language_options, value="Auto Detect")
394
-
395
- with gr.Column():
396
- speaker1 = gr.Dropdown(label="Speaker 1 Voice", choices=voice_options, value="Andrew - English (United States)")
397
- speaker2 = gr.Dropdown(label="Speaker 2 Voice", choices=voice_options, value="Ava - English (United States)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
- generate_btn = gr.Button("Generate Podcast", variant="primary")
 
 
 
 
 
400
 
401
- with gr.Row():
402
- output_audio = gr.Audio(label="Generated Podcast", type="filepath", format="wav")
403
-
404
  generate_btn.click(
405
  fn=generate_podcast_gradio,
406
  inputs=[input_text, input_file, language, speaker1, speaker2],
407
- outputs=[output_audio]
 
408
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
- demo.launch()
411
 
412
  if __name__ == "__main__":
413
- main()
 
 
 
 
 
 
 
9
  import time
10
  import mimetypes
11
  import torch
12
+ import re
13
  from typing import List, Dict
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
15
 
16
  # Constants
17
  MAX_FILE_SIZE_MB = 20
18
+ MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
19
 
20
  MODEL_ID = "unsloth/gemma-3-1b-pt"
21
 
22
+ # Initialize model with proper error handling
23
+ try:
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
25
+ if tokenizer.pad_token is None:
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ MODEL_ID,
30
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
31
+ device_map="auto",
32
+ trust_remote_code=True
33
+ ).eval()
34
+
35
+ # Configure generation parameters
36
+ generation_config = GenerationConfig(
37
+ max_new_tokens=1024,
38
+ temperature=0.7,
39
+ top_p=0.9,
40
+ do_sample=True,
41
+ pad_token_id=tokenizer.pad_token_id,
42
+ eos_token_id=tokenizer.eos_token_id,
43
+ )
44
+
45
+ print(f"Model loaded successfully on device: {model.device}")
46
+
47
+ except Exception as e:
48
+ print(f"Model initialization error: {e}")
49
+ model = None
50
+ tokenizer = None
51
+ generation_config = None
52
 
53
  class PodcastGenerator:
54
  def __init__(self):
55
+ self.model = model
56
+ self.tokenizer = tokenizer
57
+ self.generation_config = generation_config
58
+
59
+ def extract_json_from_text(self, text: str) -> Dict:
60
+ """Extract JSON from model output using regex patterns"""
61
+ # Remove the input prompt from the output
62
+ # Look for JSON-like structures
63
+ json_patterns = [
64
+ r'\{[^{}]*"topic"[^{}]*"podcast"[^{}]*\[.*?\]\s*\}',
65
+ r'\{.*?"topic".*?"podcast".*?\[.*?\].*?\}',
66
+ ]
67
+
68
+ for pattern in json_patterns:
69
+ matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
70
+ for match in matches:
71
+ try:
72
+ # Clean up the match
73
+ cleaned_match = match.strip()
74
+ return json.loads(cleaned_match)
75
+ except json.JSONDecodeError:
76
+ continue
77
+
78
+ # If no valid JSON found, create a fallback structure
79
+ return self.create_fallback_podcast(text)
80
+
81
+ def create_fallback_podcast(self, text: str) -> Dict:
82
+ """Create a basic podcast structure when JSON parsing fails"""
83
+ # Extract meaningful sentences from the text
84
+ sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10]
85
+
86
+ if not sentences:
87
+ sentences = ["Let's discuss this interesting topic.", "That's a great point to consider."]
88
+
89
+ podcast_lines = []
90
+ for i, sentence in enumerate(sentences[:10]): # Limit to 10 exchanges
91
+ speaker = (i % 2) + 1
92
+ podcast_lines.append({
93
+ "speaker": speaker,
94
+ "line": sentence + "." if not sentence.endswith('.') else sentence
95
+ })
96
+
97
+ return {
98
+ "topic": "Generated Discussion",
99
+ "podcast": podcast_lines
100
+ }
101
+
102
+ async def generate_script(self, prompt: str, language: str, file_obj=None, progress=None) -> Dict:
103
+ if not self.model or not self.tokenizer:
104
+ raise Exception("Model not properly initialized. Please check model loading.")
105
+
106
+ example_json = {
107
+ "topic": "AGI",
108
+ "podcast": [
109
+ {"speaker": 1, "line": "So, AGI, huh? Seems like everyone's talking about it these days."},
110
+ {"speaker": 2, "line": "Yeah, it's definitely having a moment, isn't it?"},
111
+ {"speaker": 1, "line": "It really is. What got you hooked on this topic?"},
112
+ {"speaker": 2, "line": "The potential implications are fascinating and concerning at the same time."}
113
+ ]
114
  }
 
 
 
115
 
116
  if language == "Auto Detect":
117
+ language_instruction = "Use the same language as the input text"
118
  else:
119
+ language_instruction = f"Generate the podcast in {language} language"
120
+
121
+ # Simplified, more direct prompt
122
+ system_prompt = f"""Generate a podcast script as valid JSON. {language_instruction}.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ Requirements:
125
+ - Exactly 2 speakers (speaker 1 and 2)
126
+ - Natural, engaging conversation
127
+ - JSON format only
128
 
129
+ Example format:
130
+ {json.dumps(example_json, indent=2)}
131
 
132
+ Input topic: {prompt}
133
 
134
+ Generate JSON:"""
 
 
 
 
 
 
 
 
135
 
136
  try:
137
  if progress:
138
  progress(0.3, "Generating podcast script...")
139
 
140
+ # Tokenize with proper attention mask
141
+ inputs = self.tokenizer(
142
+ system_prompt,
143
+ return_tensors="pt",
144
+ padding=True,
145
+ truncation=True,
146
+ max_length=2048
147
+ )
148
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
149
+
150
+ # Generate with timeout
151
+ with torch.no_grad():
152
+ output = self.model.generate(
153
+ **inputs,
154
+ generation_config=self.generation_config,
155
+ pad_token_id=self.tokenizer.pad_token_id,
156
+ )
157
+
158
+ # Decode only the new tokens
159
+ generated_text = self.tokenizer.decode(
160
+ output[0][inputs['input_ids'].shape[1]:],
161
+ skip_special_tokens=True
162
+ )
163
 
164
+ print(f"Generated text: {generated_text[:500]}...")
 
165
 
166
+ if progress:
167
+ progress(0.4, "Processing generated script...")
 
 
168
 
169
+ # Extract JSON from the generated text
170
+ result = self.extract_json_from_text(generated_text)
 
 
 
 
 
 
171
 
172
+ if progress:
173
+ progress(0.5, "Script generated successfully!")
174
 
175
+ return result
176
+
177
+ except Exception as e:
178
+ print(f"Generation error: {e}")
179
+ # Return fallback podcast
180
+ return {
181
+ "topic": prompt or "Discussion",
182
+ "podcast": [
183
+ {"speaker": 1, "line": f"Welcome to our discussion about {prompt or 'this topic'}."},
184
+ {"speaker": 2, "line": "Thanks for having me. This is indeed an interesting subject."},
185
+ {"speaker": 1, "line": "Let's dive into the key points and explore different perspectives."},
186
+ {"speaker": 2, "line": "Absolutely. There's a lot to unpack here."},
187
+ {"speaker": 1, "line": "What aspects do you find most compelling?"},
188
+ {"speaker": 2, "line": "The implications and potential applications are fascinating."},
189
+ {"speaker": 1, "line": "That's a great point. Thanks for the insightful discussion."},
190
+ {"speaker": 2, "line": "Thank you. This has been a valuable conversation."}
191
+ ]
192
+ }
193
 
194
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
195
+ """Generate TTS audio with improved error handling"""
196
  voice = speaker1 if speaker == 1 else speaker2
197
  speech = edge_tts.Communicate(text, voice)
198
 
199
+ temp_filename = f"temp_audio_{uuid.uuid4()}.wav"
200
+ max_retries = 3
201
+
202
+ for attempt in range(max_retries):
203
+ try:
204
+ await asyncio.wait_for(speech.save(temp_filename), timeout=30)
205
+ if os.path.exists(temp_filename) and os.path.getsize(temp_filename) > 0:
206
+ return temp_filename
207
+ else:
208
+ raise Exception("Generated audio file is empty")
209
+ except asyncio.TimeoutError:
210
+ if os.path.exists(temp_filename):
211
+ os.remove(temp_filename)
212
+ if attempt == max_retries - 1:
213
+ raise Exception("TTS generation timed out after multiple attempts")
214
+ await asyncio.sleep(1) # Brief delay before retry
215
+ except Exception as e:
216
+ if os.path.exists(temp_filename):
217
+ os.remove(temp_filename)
218
+ if attempt == max_retries - 1:
219
+ raise Exception(f"TTS generation failed: {str(e)}")
220
+ await asyncio.sleep(1)
221
 
222
  async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
223
+ """Combine audio files with silence padding"""
224
  if progress:
225
  progress(0.9, "Combining audio files...")
226
 
227
+ try:
228
+ combined_audio = AudioSegment.empty()
229
+ silence_padding = AudioSegment.silent(duration=500) # 500ms silence
230
+
231
+ for i, audio_file in enumerate(audio_files):
232
+ try:
233
+ audio_segment = AudioSegment.from_file(audio_file)
234
+ combined_audio += audio_segment
235
+
236
+ # Add silence between speakers (except for the last file)
237
+ if i < len(audio_files) - 1:
238
+ combined_audio += silence_padding
239
+
240
+ except Exception as e:
241
+ print(f"Warning: Could not process audio file {audio_file}: {e}")
242
+ finally:
243
+ # Clean up temporary file
244
+ if os.path.exists(audio_file):
245
+ os.remove(audio_file)
246
+
247
+ if len(combined_audio) == 0:
248
+ raise Exception("No audio content generated")
249
+
250
+ output_filename = f"podcast_output_{uuid.uuid4()}.wav"
251
+ combined_audio.export(output_filename, format="wav")
252
+
253
+ if progress:
254
+ progress(1.0, "Podcast generated successfully!")
255
+
256
+ return output_filename
257
 
258
+ except Exception as e:
259
+ # Clean up any remaining temp files
260
+ for audio_file in audio_files:
261
+ if os.path.exists(audio_file):
262
+ os.remove(audio_file)
263
+ raise Exception(f"Audio combination failed: {str(e)}")
264
 
265
  async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, file_obj=None, progress=None) -> str:
266
+ """Main podcast generation pipeline with improved error handling"""
267
  try:
268
  if progress:
269
  progress(0.1, "Starting podcast generation...")
270
+
271
+ # Generate script
272
+ podcast_json = await self.generate_script(input_text, language, file_obj, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ if not podcast_json.get('podcast'):
275
+ raise Exception("No podcast content generated")
276
+
277
+ if progress:
278
+ progress(0.5, "Converting text to speech...")
279
+
280
+ # Generate TTS with sequential processing to avoid overload
281
+ audio_files = []
282
+ total_lines = len(podcast_json['podcast'])
283
 
284
+ for i, item in enumerate(podcast_json['podcast']):
285
+ try:
286
+ audio_file = await self.tts_generate(
287
+ item['line'],
288
+ item['speaker'],
289
+ speaker1,
290
+ speaker2
291
+ )
292
+ audio_files.append(audio_file)
293
+
294
+ # Update progress
295
+ if progress:
296
+ current_progress = 0.5 + (0.4 * (i + 1) / total_lines)
297
+ progress(current_progress, f"Generated speech {i + 1}/{total_lines}")
298
 
299
+ except Exception as e:
300
+ print(f"TTS error for line {i}: {e}")
301
+ # Continue with remaining lines
302
+ continue
 
 
 
 
 
 
 
 
 
 
303
 
304
+ if not audio_files:
305
+ raise Exception("No audio files generated successfully")
306
+
307
+ # Combine audio files
308
+ combined_audio = await self.combine_audio_files(audio_files, progress)
309
+ return combined_audio
310
 
311
+ except Exception as e:
312
+ raise Exception(f"Podcast generation failed: {str(e)}")
313
+
314
+ # Voice mapping
315
+ VOICE_MAPPING = {
316
+ "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
317
+ "Ava - English (United States)": "en-US-AvaMultilingualNeural",
318
+ "Brian - English (United States)": "en-US-BrianMultilingualNeural",
319
+ "Emma - English (United States)": "en-US-EmmaMultilingualNeural",
320
+ "Florian - German (Germany)": "de-DE-FlorianMultilingualNeural",
321
+ "Seraphina - German (Germany)": "de-DE-SeraphinaMultilingualNeural",
322
+ "Remy - French (France)": "fr-FR-RemyMultilingualNeural",
323
+ "Vivienne - French (France)": "fr-FR-VivienneMultilingualNeural"
324
+ }
325
 
326
+ async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, progress=None) -> str:
327
+ """Process input and generate podcast"""
328
+ start_time = time.time()
329
 
330
  try:
331
  if progress:
332
  progress(0.05, "Processing input...")
333
 
334
+ # Map speaker names to voice IDs
335
+ speaker1_voice = VOICE_MAPPING.get(speaker1, "en-US-AndrewMultilingualNeural")
336
+ speaker2_voice = VOICE_MAPPING.get(speaker2, "en-US-AvaMultilingualNeural")
337
+
338
+ # Validate input
339
+ if not input_text or input_text.strip() == "":
340
+ if input_file is None:
341
+ raise Exception("Please provide either text input or upload a file")
342
+ # TODO: Add file processing logic here if needed
343
 
344
  podcast_generator = PodcastGenerator()
345
+ result = await podcast_generator.generate_podcast(
346
+ input_text, language, speaker1_voice, speaker2_voice, input_file, progress
347
+ )
348
 
349
  end_time = time.time()
350
+ print(f"Total generation time: {end_time - start_time:.2f} seconds")
351
+ return result
352
+
353
  except Exception as e:
 
354
  error_msg = str(e)
355
+ print(f"Processing error: {error_msg}")
356
+ raise Exception(f"Generation failed: {error_msg}")
 
 
 
 
357
 
 
358
  def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2):
359
+ """Gradio interface function with proper error handling"""
360
+ try:
361
+ # Validate inputs
362
+ if not input_text and input_file is None:
363
+ return None
364
+
365
+ if input_text and len(input_text.strip()) == 0:
366
+ input_text = None
367
 
368
+ # Create a simple progress tracker
369
+ progress_history = []
 
 
370
 
371
+ def progress_callback(value, text):
372
+ progress_history.append(f"{value:.1%}: {text}")
373
+ print(f"Progress: {value:.1%} - {text}")
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
+ # Run the async function
376
+ loop = asyncio.new_event_loop()
377
+ asyncio.set_event_loop(loop)
378
+ try:
379
+ result = loop.run_until_complete(
380
+ process_input(input_text, input_file, language, speaker1, speaker2, progress_callback)
381
+ )
382
+ return result
383
+ finally:
384
+ loop.close()
385
+
386
+ except Exception as e:
387
+ print(f"Gradio function error: {e}")
388
+ raise gr.Error(f"Failed to generate podcast: {str(e)}")
389
+
390
+ def create_interface():
391
+ """Create the Gradio interface with proper component configuration"""
392
  language_options = [
393
+ "Auto Detect", "English", "German", "French", "Spanish", "Italian",
394
+ "Portuguese", "Dutch", "Russian", "Chinese", "Japanese", "Korean"
 
 
 
 
 
 
 
 
 
 
 
395
  ]
396
 
397
+ voice_options = list(VOICE_MAPPING.keys())
 
 
 
 
 
 
 
 
 
 
398
 
399
+ with gr.Blocks(
400
+ title="PodcastGen 2🎙️",
401
+ theme=gr.themes.Soft(),
402
+ css=".gradio-container {max-width: 1200px; margin: auto;}"
403
+ ) as demo:
404
+
405
+ gr.Markdown("# 🎙️ PodcastGen 2")
406
+ gr.Markdown("Generate professional 2-speaker podcasts from text input!")
407
 
408
  with gr.Row():
409
  with gr.Column(scale=2):
410
+ input_text = gr.Textbox(
411
+ label="Input Text",
412
+ lines=8,
413
+ placeholder="Enter your topic or text for podcast generation...",
414
+ info="Describe what you want the podcast to discuss"
415
+ )
416
 
417
  with gr.Column(scale=1):
418
+ input_file = gr.File(
419
+ label="Upload File (Optional)",
420
+ file_types=[".pdf", ".txt"],
421
+ info=f"Max size: {MAX_FILE_SIZE_MB}MB"
422
+ )
423
 
424
  with gr.Row():
425
+ language = gr.Dropdown(
426
+ label="Language",
427
+ choices=language_options,
428
+ value="Auto Detect",
429
+ info="Select output language"
430
+ )
431
+
432
+ speaker1 = gr.Dropdown(
433
+ label="Speaker 1 Voice",
434
+ choices=voice_options,
435
+ value="Andrew - English (United States)"
436
+ )
437
+
438
+ speaker2 = gr.Dropdown(
439
+ label="Speaker 2 Voice",
440
+ choices=voice_options,
441
+ value="Ava - English (United States)"
442
+ )
443
+
444
+ generate_btn = gr.Button(
445
+ "🎙️ Generate Podcast",
446
+ variant="primary",
447
+ size="lg"
448
+ )
449
 
450
+ output_audio = gr.Audio(
451
+ label="Generated Podcast",
452
+ type="filepath",
453
+ format="wav",
454
+ show_download_button=True
455
+ )
456
 
457
+ # Connect the interface
 
 
458
  generate_btn.click(
459
  fn=generate_podcast_gradio,
460
  inputs=[input_text, input_file, language, speaker1, speaker2],
461
+ outputs=[output_audio],
462
+ show_progress=True
463
  )
464
+
465
+ # Add usage instructions
466
+ with gr.Accordion("Usage Instructions", open=False):
467
+ gr.Markdown("""
468
+ ### How to use:
469
+ 1. **Input**: Enter your topic or text in the text box, or upload a PDF/TXT file
470
+ 2. **Language**: Choose the output language (Auto Detect recommended)
471
+ 3. **Voices**: Select different voices for Speaker 1 and Speaker 2
472
+ 4. **Generate**: Click the button and wait for processing
473
+
474
+ ### Tips:
475
+ - Provide clear, specific topics for better results
476
+ - The AI will create a natural conversation between two speakers
477
+ - Generation may take 1-3 minutes depending on text length
478
+ """)
479
 
480
+ return demo
481
 
482
  if __name__ == "__main__":
483
+ demo = create_interface()
484
+ demo.launch(
485
+ server_name="0.0.0.0",
486
+ server_port=7860,
487
+ show_error=True,
488
+ share=False
489
+ )