CamiloVega committed on
Commit
432f68e
·
verified ·
1 Parent(s): 833724b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +951 -406
app.py CHANGED
@@ -10,15 +10,15 @@ import torch
10
  import whisper
11
  import subprocess
12
  from pydub import AudioSegment
13
- import fitz
14
  import docx
15
  import yt_dlp
16
  from functools import lru_cache
17
  import gc
18
  import time
19
  from huggingface_hub import login
20
- from unsloth import FastLanguageModel
21
- from transformers import AutoTokenizer
22
 
23
  # Configure logging
24
  logging.basicConfig(
@@ -30,626 +30,1171 @@ logger = logging.getLogger(__name__)
30
  # Login to Hugging Face Hub if token is available
31
  HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
32
  if HUGGINGFACE_TOKEN:
33
- login(token=HUGGINGFACE_TOKEN)
 
 
 
 
34
 
35
  class ModelManager:
36
  _instance = None
37
-
38
  def __new__(cls):
39
  if cls._instance is None:
40
  cls._instance = super(ModelManager, cls).__new__(cls)
41
  cls._instance._initialized = False
42
  return cls._instance
43
-
44
  def __init__(self):
45
  if not self._initialized:
46
  self.tokenizer = None
47
  self.model = None
48
- self.pipeline = None
49
  self.whisper_model = None
50
  self._initialized = True
51
  self.last_used = time.time()
52
-
53
- @spaces.GPU()
 
 
54
  def initialize_llm(self):
55
- """Initialize LLM model with Unsloth optimization"""
 
 
 
 
 
 
 
 
 
56
  try:
 
57
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
58
-
59
- logger.info("Loading Unsloth-optimized model...")
60
- self.model, self.tokenizer = FastLanguageModel.from_pretrained(
61
- model_name = MODEL_NAME,
62
- max_seq_length = 2048,
63
- dtype = torch.float16,
64
- load_in_4bit = True,
65
- token = HUGGINGFACE_TOKEN,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
-
68
- # Enable LoRA for better ZeroGPU performance
69
- self.model = FastLanguageModel.get_peft_model(
70
- self.model,
71
- r = 16,
72
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
73
- "gate_proj", "up_proj", "down_proj"],
74
- lora_alpha = 16,
75
- lora_dropout = 0,
76
- bias = "none",
77
- use_gradient_checkpointing = True,
78
- random_state = 3407,
79
- max_seq_length = 2048,
80
  )
81
-
82
- logger.info("LLM initialized successfully with Unsloth")
83
  self.last_used = time.time()
 
84
  return True
85
-
86
  except Exception as e:
87
  logger.error(f"Error initializing LLM: {str(e)}")
88
- raise
 
 
 
 
 
 
 
 
 
89
 
90
- @spaces.GPU()
91
  def initialize_whisper(self):
92
- """Initialize Whisper model with safety fix"""
 
 
 
 
 
 
 
 
 
93
  try:
94
  logger.info("Loading Whisper model...")
95
- # Load with weights_only=True for security
 
 
 
 
96
  self.whisper_model = whisper.load_model(
97
- "tiny",
98
  device="cuda" if torch.cuda.is_available() else "cpu",
99
- download_root="/tmp/whisper",
100
- weights_only=True # Security fix
101
  )
102
  logger.info("Whisper model initialized successfully")
103
  self.last_used = time.time()
 
104
  return True
105
  except Exception as e:
106
  logger.error(f"Error initializing Whisper: {str(e)}")
 
 
 
 
 
 
107
  raise
108
 
109
  def check_llm_initialized(self):
110
  """Check if LLM is initialized and initialize if needed"""
111
- if self.tokenizer is None or self.model is None:
112
  logger.info("LLM not initialized, initializing...")
113
- self.initialize_llm()
 
 
 
 
 
 
 
114
  self.last_used = time.time()
115
-
116
  def check_whisper_initialized(self):
117
  """Check if Whisper model is initialized and initialize if needed"""
118
  if self.whisper_model is None:
119
  logger.info("Whisper model not initialized, initializing...")
120
- self.initialize_whisper()
 
 
 
 
 
 
121
  self.last_used = time.time()
122
-
123
  def reset_models(self, force=False):
124
  """Reset models to free memory if they haven't been used recently"""
125
  current_time = time.time()
126
- if force or (current_time - self.last_used > 600):
 
127
  try:
128
  logger.info("Resetting models to free memory...")
129
-
130
- if self.model is not None:
 
131
  del self.model
132
-
133
- if self.tokenizer is not None:
 
 
134
  del self.tokenizer
135
-
136
- if self.whisper_model is not None:
 
 
 
 
 
 
 
137
  del self.whisper_model
138
-
139
- self.tokenizer = None
140
- self.model = None
141
- self.whisper_model = None
142
-
143
  if torch.cuda.is_available():
144
  torch.cuda.empty_cache()
145
- torch.cuda.synchronize()
146
-
 
147
  gc.collect()
148
- logger.info("Models reset successfully")
149
-
 
150
  except Exception as e:
151
  logger.error(f"Error resetting models: {str(e)}")
 
152
 
 
153
  model_manager = ModelManager()
154
 
155
- @lru_cache(maxsize=32)
156
  def download_social_media_video(url):
157
- """Download a video from social media."""
 
 
 
158
  ydl_opts = {
159
  'format': 'bestaudio/best',
160
  'postprocessors': [{
161
  'key': 'FFmpegExtractAudio',
162
  'preferredcodec': 'mp3',
163
- 'preferredquality': '192',
164
  }],
165
- 'outtmpl': '%(id)s.%(ext)s',
 
 
 
 
 
166
  }
167
  try:
 
168
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
169
  info_dict = ydl.extract_info(url, download=True)
170
- audio_file = f"{info_dict['id']}.mp3"
171
- logger.info(f"Video downloaded successfully: {audio_file}")
172
- return audio_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  except Exception as e:
174
- logger.error(f"Error downloading video: {str(e)}")
175
- raise
 
 
 
 
 
 
 
 
176
 
177
- def convert_video_to_audio(video_file):
178
  """Convert a video file to audio using ffmpeg directly."""
179
  try:
 
180
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
181
- output_file = temp_file.name
182
-
 
 
 
183
  command = [
184
- "ffmpeg",
185
- "-i", video_file,
186
- "-q:a", "0",
187
- "-map", "a",
188
- "-vn",
189
- output_file,
190
- "-y"
 
 
 
191
  ]
192
-
193
- subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
194
-
195
- logger.info(f"Video converted to audio: {output_file}")
196
- return output_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  except Exception as e:
198
- logger.error(f"Error converting video: {str(e)}")
199
- raise
 
 
 
 
200
 
201
- def preprocess_audio(audio_file):
202
- """Preprocess the audio file to improve quality."""
203
  try:
204
- audio = AudioSegment.from_file(audio_file)
205
- audio = audio.apply_gain(-audio.dBFS + (-20))
 
 
 
 
 
 
 
206
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
207
- audio.export(temp_file.name, format="mp3")
208
- logger.info(f"Audio preprocessed: {temp_file.name}")
209
- return temp_file.name
 
 
210
  except Exception as e:
211
- logger.error(f"Error preprocessing audio: {str(e)}")
 
 
 
212
  raise
213
 
214
- @spaces.GPU()
215
- def transcribe_audio(file):
216
- """Transcribe an audio or video file."""
 
 
 
 
217
  try:
218
  model_manager.check_whisper_initialized()
219
-
220
- if isinstance(file, str) and file.startswith('http'):
221
- file_path = download_social_media_video(file)
222
- elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
223
- file_path = convert_video_to_audio(file)
224
- elif file is not None:
225
- file_path = preprocess_audio(file.name)
 
 
 
 
 
 
 
 
 
 
226
  else:
227
- return ""
228
-
229
- logger.info(f"Transcribing audio: {file_path}")
230
- if not os.path.exists(file_path):
231
- raise FileNotFoundError(f"Audio file not found: {file_path}")
232
-
233
- with torch.inference_mode():
234
- result = model_manager.whisper_model.transcribe(file_path)
235
-
236
- transcription = result.get("text", "Error in transcription")
237
- logger.info(f"Transcription completed: {transcription[:50]}...")
238
-
 
 
 
 
 
 
 
239
  try:
240
- if os.path.exists(file_path):
241
- os.remove(file_path)
242
- except Exception as e:
243
- logger.warning(f"Could not remove temp file {file_path}: {str(e)}")
244
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  return transcription
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  except Exception as e:
247
- logger.error(f"Error transcribing: {str(e)}")
248
- return f"Error processing the file: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
- @lru_cache(maxsize=32)
251
  def read_document(document_path):
252
- """Read the content of a document."""
253
  try:
254
- if document_path.endswith(".pdf"):
 
 
 
 
 
 
255
  doc = fitz.open(document_path)
256
- return "\n".join([page.get_text() for page in doc])
257
- elif document_path.endswith(".docx"):
 
 
258
  doc = docx.Document(document_path)
259
  return "\n".join([paragraph.text for paragraph in doc.paragraphs])
260
- elif document_path.endswith((".xlsx", ".xls")):
261
- return pd.read_excel(document_path).to_string()
262
- elif document_path.endswith(".csv"):
263
- return pd.read_csv(document_path).to_string()
 
 
 
 
 
 
 
 
 
 
 
 
264
  else:
 
265
  return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
 
 
 
266
  except Exception as e:
267
- logger.error(f"Error reading document: {str(e)}")
 
268
  return f"Error reading document: {str(e)}"
269
 
270
- @lru_cache(maxsize=32)
271
  def read_url(url):
272
- """Read the content of a URL."""
273
- if not url or url.strip() == "":
274
- return ""
275
-
 
276
  try:
 
277
  headers = {
278
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
279
  }
280
- response = requests.get(url, headers=headers, timeout=15)
281
- response.raise_for_status()
 
 
 
 
 
 
 
 
282
  soup = BeautifulSoup(response.content, 'html.parser')
283
-
284
- for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
 
285
  element.extract()
286
-
287
- main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
 
 
 
 
 
 
 
288
  if main_content:
289
  text = main_content.get_text(separator='\n', strip=True)
290
  else:
291
- text = soup.get_text(separator='\n', strip=True)
292
-
293
- lines = [line.strip() for line in text.split('\n') if line.strip()]
294
- text = '\n'.join(lines)
295
-
296
- return text[:10000]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  except Exception as e:
298
- logger.error(f"Error reading URL: {str(e)}")
299
- return f"Error reading URL: {str(e)}"
 
300
 
301
- def process_social_content(url):
302
- """Process social media content."""
303
- if not url or url.strip() == "":
 
304
  return None
305
-
 
 
 
 
 
 
306
  try:
307
  text_content = read_url(url)
308
- try:
309
- video_content = transcribe_audio(url)
310
- except Exception as e:
311
- logger.error(f"Error processing video content: {str(e)}")
312
- video_content = None
 
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  return {
315
- "text": text_content,
316
- "video": video_content
317
  }
318
- except Exception as e:
319
- logger.error(f"Error processing social content: {str(e)}")
320
- return None
 
321
 
322
- @spaces.GPU()
323
  def generate_news(instructions, facts, size, tone, *args):
324
- """Generate a news article based on provided data"""
 
 
325
  try:
326
- if isinstance(size, float):
327
- size = int(size)
328
- elif not isinstance(size, int):
 
 
329
  size = 250
330
-
331
- model_manager.check_llm_initialized()
332
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  knowledge_base = {
334
- "instructions": instructions or "",
335
- "facts": facts or "",
336
  "document_content": [],
337
  "audio_data": [],
338
  "url_content": [],
339
  "social_content": []
340
  }
 
341
 
342
- num_audios = 5 * 3
343
- num_social_urls = 3 * 3
344
- num_urls = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
- args = list(args)
347
-
348
- while len(args) < (num_audios + num_social_urls + num_urls + 5):
349
- args.append("")
350
-
351
- audios = args[:num_audios]
352
- social_urls = args[num_audios:num_audios+num_social_urls]
353
- urls = args[num_audios+num_social_urls:num_audios+num_social_urls+num_urls]
354
- documents = args[num_audios+num_social_urls+num_urls:]
355
-
356
- logger.info("Processing URLs...")
357
- for url in urls:
358
- if url and isinstance(url, str) and url.strip():
359
- content = read_url(url)
360
- if content and not content.startswith("Error"):
361
- knowledge_base["url_content"].append(content)
362
-
363
- logger.info("Processing documents...")
364
- for document in documents:
365
- if document and hasattr(document, 'name'):
366
- content = read_document(document.name)
367
- if content and not content.startswith("Error"):
368
- knowledge_base["document_content"].append(content)
369
-
370
- logger.info("Processing audio/video files...")
371
- for i in range(0, len(audios), 3):
372
- if i+2 < len(audios):
373
- audio_file, name, position = audios[i:i+3]
374
- if audio_file and hasattr(audio_file, 'name'):
375
- knowledge_base["audio_data"].append({
376
- "audio": audio_file,
377
- "name": name or "Unknown",
378
- "position": position or "Not specified"
379
- })
380
-
381
- logger.info("Processing social media content...")
382
- for i in range(0, len(social_urls), 3):
383
- if i+2 < len(social_urls):
384
- social_url, social_name, social_context = social_urls[i:i+3]
385
- if social_url and isinstance(social_url, str) and social_url.strip():
386
- social_content = process_social_content(social_url)
387
- if social_content:
388
- knowledge_base["social_content"].append({
389
  "url": social_url,
390
- "name": social_name or "Unknown",
391
- "context": social_context or "Not specified",
392
- "text": social_content.get("text", ""),
393
- "video": social_content.get("video", "")
394
  })
 
 
 
 
 
 
395
 
396
- transcriptions_text = ""
397
- raw_transcriptions = ""
398
 
399
- logger.info("Transcribing audio...")
400
- for idx, data in enumerate(knowledge_base["audio_data"]):
401
- if data["audio"] is not None:
402
- transcription = transcribe_audio(data["audio"])
403
- if transcription and not transcription.startswith("Error"):
404
- transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n\n'
405
- raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
 
407
  for idx, data in enumerate(knowledge_base["social_content"]):
408
- if data["text"] and not str(data["text"]).startswith("Error"):
409
- text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
410
- social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
411
- transcriptions_text += social_text
412
- raw_transcriptions += social_text
413
-
414
- if data["video"] and not str(data["video"]).startswith("Error"):
415
- video_transcription = f'[Social media {idx+1} - video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n\n'
416
- transcriptions_text += video_transcription
417
- raw_transcriptions += video_transcription
418
-
419
- document_summaries = []
420
- for idx, doc in enumerate(knowledge_base["document_content"]):
421
- if len(doc) > 1000:
422
- doc_excerpt = doc[:1000] + "... [document continues]"
423
- else:
424
- doc_excerpt = doc
425
- document_summaries.append(f"[Document {idx+1}]: {doc_excerpt}")
426
-
427
- document_content = "\n\n".join(document_summaries)
428
-
429
- url_summaries = []
430
- for idx, url_content in enumerate(knowledge_base["url_content"]):
431
- if len(url_content) > 1000:
432
- url_excerpt = url_content[:1000] + "... [content continues]"
433
- else:
434
- url_excerpt = url_content
435
- url_summaries.append(f"[URL {idx+1}]: {url_excerpt}")
436
-
437
- url_content = "\n\n".join(url_summaries)
438
 
439
- prompt = f"""<s>[INST] You are a professional news writer. Write a news article based on the following information:
 
440
 
441
- Instructions: {knowledge_base["instructions"]}
 
442
 
443
- Facts: {knowledge_base["facts"]}
444
 
445
- Additional content from documents:
446
- {document_content}
447
 
448
- Additional content from URLs:
449
- {url_content}
450
 
451
- Use these transcriptions as direct and indirect quotes:
452
- {transcriptions_text}
453
 
454
- Follow these requirements:
455
- - Write a title
456
- - Write a 15-word hook that complements the title
457
- - Write the body with approximately {size} words
458
- - Use a {tone} tone
459
- - Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
460
- - Use at least 80% direct quotes (in quotation marks)
461
- - Use proper journalistic style
462
- - Do not invent information
463
- - Be rigorous with the provided facts [/INST]"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
  try:
466
- logger.info("Generating news article...")
467
-
468
- max_length = min(len(prompt.split()) + size * 2, 2048)
469
-
470
- inputs = model_manager.tokenizer(
471
  prompt,
472
- return_tensors = "pt",
473
- padding = True,
474
- truncation = True,
475
- max_length = 2048,
476
- ).to("cuda")
477
-
478
- outputs = model_manager.model.generate(
479
- **inputs,
480
- max_new_tokens = size + 100,
481
- temperature = 0.7,
482
- do_sample = True,
483
- pad_token_id = model_manager.tokenizer.eos_token_id,
484
  )
485
-
486
- generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens = True)
487
-
488
- if "[/INST]" in generated_text:
489
- news_article = generated_text.split("[/INST]")[1].strip()
 
 
 
 
 
 
 
 
490
  else:
491
- prompt_fragment = " ".join(prompt.split()[:50])
492
- if prompt_fragment in generated_text:
493
- news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
494
- else:
495
- news_article = generated_text
496
-
497
- logger.info(f"News generation completed: {len(news_article)} chars")
498
-
 
 
 
 
 
 
 
 
 
 
499
  except Exception as gen_error:
500
- logger.error(f"Error in text generation: {str(gen_error)}")
501
- raise
502
-
503
- return news_article, raw_transcriptions
 
 
 
 
 
504
 
505
  except Exception as e:
506
- logger.error(f"Error generating news: {str(e)}")
 
 
 
507
  try:
508
  model_manager.reset_models(force=True)
509
  except Exception as reset_error:
510
- logger.error(f"Failed to reset models: {str(reset_error)}")
511
- return f"Error generating news: {str(e)}", "Error processing transcriptions."
 
 
 
512
 
513
  def create_demo():
 
514
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
515
  gr.Markdown("# 📰 NewsIA - AI News Generator")
516
- gr.Markdown("Create professional news articles from multiple sources.")
517
-
 
 
 
518
  with gr.Row():
519
  with gr.Column(scale=2):
520
  instructions = gr.Textbox(
521
- label="News Instructions",
522
- placeholder="Enter specific instructions for news generation",
523
- lines=2
 
524
  )
 
 
525
  facts = gr.Textbox(
526
- label="Key Facts",
527
- placeholder="Describe the most important facts to include",
528
- lines=4
 
529
  )
530
-
 
531
  with gr.Row():
532
- size = gr.Slider(
533
  label="Approximate Length (words)",
534
  minimum=100,
535
- maximum=500,
536
  value=250,
537
  step=50
538
  )
539
- tone = gr.Dropdown(
540
- label="News Tone",
541
- choices=["serious", "neutral", "funny", "formal", "informal", "urgent"],
 
 
542
  value="neutral"
543
  )
 
544
 
545
  with gr.Column(scale=3):
546
- inputs_list = []
547
- inputs_list.extend([instructions, facts, size, tone])
548
-
549
  with gr.Tabs():
550
  with gr.TabItem("πŸ“ Documents"):
551
- documents = []
 
552
  for i in range(1, 6):
553
- doc = gr.File(
554
  label=f"Document {i}",
555
- file_types=["pdf", "docx", "xlsx", "csv"],
556
- file_count="single"
557
  )
558
- documents.append(doc)
559
- inputs_list.append(doc)
560
 
561
  with gr.TabItem("🔊 Audio/Video"):
562
- for i in range(1, 6):
 
 
563
  with gr.Group():
564
  gr.Markdown(f"**Source {i}**")
565
- file = gr.File(
566
- label=f"Audio/Video {i}",
567
  file_types=["audio", "video"]
568
  )
569
  with gr.Row():
570
- name = gr.Textbox(
571
- label="Name",
572
- placeholder="Interviewee name"
 
573
  )
574
- position = gr.Textbox(
575
- label="Position/Role",
576
- placeholder="Position or role"
 
577
  )
578
- inputs_list.extend([file, name, position])
 
 
 
579
 
580
  with gr.TabItem("🌐 URLs"):
581
- for i in range(1, 6):
582
- url = gr.Textbox(
 
 
583
  label=f"URL {i}",
584
- placeholder="https://..."
 
585
  )
586
- inputs_list.append(url)
 
587
 
588
  with gr.TabItem("📱 Social Media"):
589
- for i in range(1, 4):
 
 
590
  with gr.Group():
591
- gr.Markdown(f"**Social Media {i}**")
592
- social_url = gr.Textbox(
593
- label="URL",
594
- placeholder="https://..."
 
595
  )
596
  with gr.Row():
597
- social_name = gr.Textbox(
598
- label="Account/Name",
599
- placeholder="Account or person name"
 
600
  )
601
- social_context = gr.Textbox(
602
- label="Context",
603
- placeholder="Relevant context"
 
604
  )
605
- inputs_list.extend([social_url, social_name, social_context])
 
 
 
 
606
 
607
  with gr.Row():
608
- generate_btn = gr.Button("✨ Generate News", variant="primary")
609
- reset_btn = gr.Button("🔄 Clear All")
610
 
611
  with gr.Tabs():
612
- with gr.TabItem("📄 Generated News"):
613
  news_output = gr.Textbox(
614
- label="News Draft",
615
- lines=15,
616
- show_copy_button=True
 
617
  )
618
-
619
- with gr.TabItem("🎙️ Transcriptions"):
620
  transcriptions_output = gr.Textbox(
621
- label="Source Transcriptions",
622
- lines=10,
623
- show_copy_button=True
 
624
  )
625
 
626
- generate_btn.click(
 
 
 
 
 
627
  fn=generate_news,
628
- inputs=inputs_list,
629
- outputs=[news_output, transcriptions_output]
630
  )
631
-
632
- def reset_all():
633
- return [None]*len(inputs_list) + ["", ""]
634
-
635
- reset_btn.click(
636
- fn=reset_all,
637
- inputs=None,
638
- outputs=inputs_list + [news_output, transcriptions_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  )
640
 
 
 
 
641
  return demo
642
 
643
  if __name__ == "__main__":
644
- try:
645
- model_manager.initialize_whisper()
646
- except Exception as e:
647
- logger.warning(f"Initial whisper model loading failed: {str(e)}")
648
-
649
- demo = create_demo()
650
- demo.queue(max_size=5)
651
- demo.launch(
652
- share=True,
653
- server_name="0.0.0.0",
654
- server_port=7860
655
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import whisper
11
  import subprocess
12
  from pydub import AudioSegment
13
+ import fitz # PyMuPDF
14
  import docx
15
  import yt_dlp
16
  from functools import lru_cache
17
  import gc
18
  import time
19
  from huggingface_hub import login
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
21
+ import traceback # For detailed error logging
22
 
23
  # Configure logging
24
  logging.basicConfig(
 
30
  # Login to Hugging Face Hub if token is available
31
  HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
32
  if HUGGINGFACE_TOKEN:
33
+ try:
34
+ login(token=HUGGINGFACE_TOKEN)
35
+ logger.info("Successfully logged in to Hugging Face Hub.")
36
+ except Exception as e:
37
+ logger.error(f"Failed to login to Hugging Face Hub: {e}")
38
 
39
  class ModelManager:
40
  _instance = None
41
+
42
  def __new__(cls):
43
  if cls._instance is None:
44
  cls._instance = super(ModelManager, cls).__new__(cls)
45
  cls._instance._initialized = False
46
  return cls._instance
47
+
48
  def __init__(self):
49
  if not self._initialized:
50
  self.tokenizer = None
51
  self.model = None
52
+ self.text_pipeline = None # Renamed for clarity
53
  self.whisper_model = None
54
  self._initialized = True
55
  self.last_used = time.time()
56
+ self.llm_loading = False
57
+ self.whisper_loading = False
58
+
59
+ @spaces.GPU(duration=120) # Increased duration for potentially long loads
60
  def initialize_llm(self):
61
+ """Initialize LLM model with standard transformers"""
62
+ if self.llm_loading:
63
+ logger.info("LLM initialization already in progress.")
64
+ return True # Assume it will succeed or fail elsewhere
65
+ if self.tokenizer and self.model and self.text_pipeline:
66
+ logger.info("LLM already initialized.")
67
+ self.last_used = time.time()
68
+ return True
69
+
70
+ self.llm_loading = True
71
  try:
72
+ # Use small model for ZeroGPU compatibility
73
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
74
+
75
+ logger.info("Loading LLM tokenizer...")
76
+ self.tokenizer = AutoTokenizer.from_pretrained(
77
+ MODEL_NAME,
78
+ token=HUGGINGFACE_TOKEN,
79
+ use_fast=True
80
+ )
81
+
82
+ if self.tokenizer.pad_token is None:
83
+ self.tokenizer.pad_token = self.tokenizer.eos_token
84
+
85
+ # Basic memory settings for ZeroGPU
86
+ logger.info("Loading LLM model...")
87
+ self.model = AutoModelForCausalLM.from_pretrained(
88
+ MODEL_NAME,
89
+ token=HUGGINGFACE_TOKEN,
90
+ device_map="auto",
91
+ torch_dtype=torch.float16,
92
+ low_cpu_mem_usage=True,
93
+ # Optimizations for ZeroGPU
94
+ # max_memory={0: "4GB"}, # Removed for better auto handling initially
95
+ offload_folder="offload",
96
+ offload_state_dict=True
97
  )
98
+
99
+ # Create text generation pipeline
100
+ logger.info("Creating LLM text generation pipeline...")
101
+ self.text_pipeline = pipeline(
102
+ "text-generation",
103
+ model=self.model,
104
+ tokenizer=self.tokenizer,
105
+ torch_dtype=torch.float16,
106
+ device_map="auto",
107
+ max_length=1024 # Default max length
 
 
 
108
  )
109
+
110
+ logger.info("LLM initialized successfully")
111
  self.last_used = time.time()
112
+ self.llm_loading = False
113
  return True
114
+
115
  except Exception as e:
116
  logger.error(f"Error initializing LLM: {str(e)}")
117
+ logger.error(traceback.format_exc()) # Log full traceback
118
+ # Reset partially loaded components
119
+ self.tokenizer = None
120
+ self.model = None
121
+ self.text_pipeline = None
122
+ if torch.cuda.is_available():
123
+ torch.cuda.empty_cache()
124
+ gc.collect()
125
+ self.llm_loading = False
126
+ raise # Re-raise the exception to signal failure
127
 
128
+ @spaces.GPU(duration=120) # Increased duration
129
  def initialize_whisper(self):
130
+ """Initialize Whisper model for audio transcription"""
131
+ if self.whisper_loading:
132
+ logger.info("Whisper initialization already in progress.")
133
+ return True
134
+ if self.whisper_model:
135
+ logger.info("Whisper already initialized.")
136
+ self.last_used = time.time()
137
+ return True
138
+
139
+ self.whisper_loading = True
140
  try:
141
  logger.info("Loading Whisper model...")
142
+ # Using tiny model for efficiency but can be changed based on needs
143
+ # Specify weights_only=True to address the FutureWarning
144
+ # Note: Whisper's load_model might not directly support weights_only yet.
145
+ # If it errors, remove the weights_only=True. The warning is mainly informative.
146
+ # Let's attempt without weights_only first as whisper might handle it internally
147
  self.whisper_model = whisper.load_model(
148
+ "tiny", # Consider "base" for better accuracy if "tiny" struggles
149
  device="cuda" if torch.cuda.is_available() else "cpu",
150
+ download_root="/tmp/whisper" # Use persistent storage if available/needed
 
151
  )
152
  logger.info("Whisper model initialized successfully")
153
  self.last_used = time.time()
154
+ self.whisper_loading = False
155
  return True
156
  except Exception as e:
157
  logger.error(f"Error initializing Whisper: {str(e)}")
158
+ logger.error(traceback.format_exc())
159
+ self.whisper_model = None
160
+ if torch.cuda.is_available():
161
+ torch.cuda.empty_cache()
162
+ gc.collect()
163
+ self.whisper_loading = False
164
  raise
165
 
166
  def check_llm_initialized(self):
167
  """Check if LLM is initialized and initialize if needed"""
168
+ if self.tokenizer is None or self.model is None or self.text_pipeline is None:
169
  logger.info("LLM not initialized, initializing...")
170
+ if not self.llm_loading: # Prevent re-entry if already loading
171
+ self.initialize_llm()
172
+ else:
173
+ logger.info("LLM initialization is already in progress by another request.")
174
+ # Optional: Wait a bit for the other process to finish
175
+ time.sleep(5)
176
+ if self.tokenizer is None or self.model is None or self.text_pipeline is None:
177
+ raise RuntimeError("LLM initialization timed out or failed.")
178
  self.last_used = time.time()
179
+
180
  def check_whisper_initialized(self):
181
  """Check if Whisper model is initialized and initialize if needed"""
182
  if self.whisper_model is None:
183
  logger.info("Whisper model not initialized, initializing...")
184
+ if not self.whisper_loading: # Prevent re-entry
185
+ self.initialize_whisper()
186
+ else:
187
+ logger.info("Whisper initialization is already in progress by another request.")
188
+ time.sleep(5)
189
+ if self.whisper_model is None:
190
+ raise RuntimeError("Whisper initialization timed out or failed.")
191
  self.last_used = time.time()
192
+
193
  def reset_models(self, force=False):
194
  """Reset models to free memory if they haven't been used recently"""
195
  current_time = time.time()
196
+ # Only reset if forced or models haven't been used for 10 minutes (600 seconds)
197
+ if force or (current_time - self.last_used > 600):
198
  try:
199
  logger.info("Resetting models to free memory...")
200
+
201
+ # Check and delete attributes safely
202
+ if hasattr(self, 'model') and self.model is not None:
203
  del self.model
204
+ self.model = None
205
+ logger.info("LLM model deleted.")
206
+
207
+ if hasattr(self, 'tokenizer') and self.tokenizer is not None:
208
  del self.tokenizer
209
+ self.tokenizer = None
210
+ logger.info("LLM tokenizer deleted.")
211
+
212
+ if hasattr(self, 'text_pipeline') and self.text_pipeline is not None:
213
+ del self.text_pipeline
214
+ self.text_pipeline = None
215
+ logger.info("LLM pipeline deleted.")
216
+
217
+ if hasattr(self, 'whisper_model') and self.whisper_model is not None:
218
  del self.whisper_model
219
+ self.whisper_model = None
220
+ logger.info("Whisper model deleted.")
221
+
222
+ # Explicitly clear CUDA cache and collect garbage
 
223
  if torch.cuda.is_available():
224
  torch.cuda.empty_cache()
225
+ # torch.cuda.synchronize() # May not be needed and can slow down
226
+ logger.info("CUDA cache cleared.")
227
+
228
  gc.collect()
229
+ logger.info("Garbage collected. Models reset successfully.")
230
+ self._initialized = False # Mark as uninitialized so they reload on next use
231
+
232
  except Exception as e:
233
  logger.error(f"Error resetting models: {str(e)}")
234
+ logger.error(traceback.format_exc())
235
 
236
+ # Create global model manager instance
237
  model_manager = ModelManager()
238
 
239
def download_social_media_video(url):
    """Download the audio track of a social media video as an MP3 file.

    yt-dlp fetches the best available audio stream into a private scratch
    directory and its FFmpegExtractAudio post-processor converts it to MP3.
    The MP3 is then moved to a standalone temporary file and the scratch
    directory is removed, whatever the outcome.

    The result is deliberately NOT memoised: process_social_media_url deletes
    the returned file once it has been transcribed, so a cached path would
    point at a file that no longer exists, and a cached ``None`` would turn a
    single transient failure into a repeated one.

    Args:
        url: Address of the post/video to download.

    Returns:
        Path to a temporary ``.mp3`` file that the caller must delete, or
        ``None`` if the download or the conversion failed.
    """
    import shutil  # local import: only this function needs it

    temp_dir = tempfile.mkdtemp()
    output_template = os.path.join(temp_dir, '%(id)s.%(ext)s')

    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',  # Standard quality
        }],
        'outtmpl': output_template,
        'quiet': True,
        'no_warnings': True,
        # NOTE(review): disables TLS certificate verification for downloads.
        # Kept for compatibility with "tricky" sites; confirm this is acceptable.
        'nocheckcertificate': True,
        'retries': 3,
        'socket_timeout': 15,
    }

    try:
        logger.info(f"Attempting to download audio from: {url}")
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)

        # Expected file name once the post-processor has produced the MP3.
        audio_file = os.path.join(temp_dir, f"{info_dict['id']}.mp3")
        if not os.path.exists(audio_file):
            # Fallback if the file name doesn't match exactly (e.g., webm -> mp3).
            found_files = sorted(f for f in os.listdir(temp_dir) if f.endswith('.mp3'))
            if not found_files:
                raise FileNotFoundError(f"Could not find downloaded MP3 in {temp_dir}")
            audio_file = os.path.join(temp_dir, found_files[0])

        logger.info(f"Audio downloaded successfully: {audio_file}")

        # Reserve a file outside the scratch directory and move the MP3 there,
        # instead of reading the whole file into memory just to copy it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_output_file:
            final_path = temp_output_file.name
        try:
            shutil.move(audio_file, final_path)
        except Exception:
            # Do not leave the reserved (empty) file behind.
            if os.path.exists(final_path):
                os.remove(final_path)
            raise

        logger.info(f"Audio saved to temporary file: {final_path}")
        return final_path

    except yt_dlp.utils.DownloadError as e:
        logger.error(f"yt-dlp download error for {url}: {str(e)}")
        return None  # Return None to indicate failure
    except Exception as e:
        logger.error(f"Unexpected error downloading video from {url}: {str(e)}")
        logger.error(traceback.format_exc())
        return None
    finally:
        # Single cleanup path for success and failure; rmtree also copes with
        # leftover partial/intermediate files that os.rmdir would choke on.
        shutil.rmtree(temp_dir, ignore_errors=True)
313
 
314
def convert_video_to_audio(video_file_path):
    """Extract the audio track of a video file into a temporary MP3 via ffmpeg.

    Args:
        video_file_path: Path of the video file to convert.

    Returns:
        Path of a newly created temporary ``.mp3`` file. The caller owns the
        file and is responsible for deleting it.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status, or finishes
            without producing a non-empty output file.
        Exception: Any other failure (for example ffmpeg not being installed)
            is logged and re-raised unchanged.
    """
    output_file_path = None

    def _discard_output():
        # Remove the (possibly empty or partial) output file after a failure.
        if output_file_path and os.path.exists(output_file_path):
            os.remove(output_file_path)

    try:
        # Reserve a temporary file path for the output MP3.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_file_path = temp_file.name

        logger.info(f"Converting video '{video_file_path}' to audio '{output_file_path}'")

        # Global options go first, as the ffmpeg documentation recommends,
        # rather than trailing after the output file.
        command = [
            "ffmpeg",
            "-nostdin",                # Never wait for interactive input
            "-y",                      # Overwrite the reserved (existing) output file
            "-loglevel", "error",      # Suppress verbose ffmpeg output
            "-i", video_file_path,
            "-vn",                     # No video
            "-acodec", "libmp3lame",   # Specify MP3 codec
            "-b:a", "192k",            # Audio bitrate
            "-ar", "44100",            # Audio sample rate
            "-ac", "2",                # Stereo audio
            output_file_path,
        ]

        process = subprocess.run(command, check=True, capture_output=True, text=True)
        logger.info(f"ffmpeg conversion successful for {video_file_path}.")
        logger.debug(f"ffmpeg stdout: {process.stdout}")
        logger.debug(f"ffmpeg stderr: {process.stderr}")

        # A zero exit status is not enough: verify the output exists and has content.
        if not os.path.exists(output_file_path) or os.path.getsize(output_file_path) == 0:
            raise RuntimeError(f"ffmpeg conversion failed: Output file '{output_file_path}' not created or is empty.")

        logger.info(f"Video converted to audio: {output_file_path}")
        return output_file_path
    except subprocess.CalledProcessError as e:
        logger.error(f"ffmpeg command failed with exit code {e.returncode}")
        logger.error(f"ffmpeg stderr: {e.stderr}")
        logger.error(f"ffmpeg stdout: {e.stdout}")
        _discard_output()
        raise RuntimeError(f"ffmpeg conversion failed: {e.stderr}") from e
    except Exception as e:
        logger.error(f"Error converting video '{video_file_path}': {str(e)}")
        logger.error(traceback.format_exc())
        _discard_output()
        raise  # Re-raise the exception
364
 
365
def preprocess_audio(input_audio_path):
    """Re-encode an audio file as a temporary MP3 using pydub.

    No gain change is applied at the moment; loudness normalisation is
    disabled (see the note in the body), so this step only decodes the input
    and exports it again as MP3.

    Args:
        input_audio_path: Path of the audio file to load.

    Returns:
        Path of a newly created temporary ``.mp3`` file owned by the caller.

    Raises:
        Exception: Any decoding/encoding error is logged and re-raised, so the
            caller can decide whether to fall back to the original file.
    """
    output_path = None
    try:
        logger.info(f"Preprocessing audio file: {input_audio_path}")
        audio = AudioSegment.from_file(input_audio_path)

        # Normalisation is intentionally disabled. To target -20 dBFS, apply:
        #   audio = audio.apply_gain(-20.0 - audio.dBFS)

        # Reserve the output path, then export once the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_path = temp_file.name
        audio.export(output_path, format="mp3")

        logger.info(f"Audio preprocessed and saved to: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error preprocessing audio '{input_audio_path}': {str(e)}")
        logger.error(traceback.format_exc())
        # The temp file was created with delete=False; remove it so a failed
        # export does not leak a file on every call.
        if output_path and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError as cleanup_err:
                logger.warning(f"Could not remove temporary file {output_path}: {cleanup_err}")
        # Raise to signal failure clearly instead of returning the input path.
        raise
389
 
390
@spaces.GPU(duration=300)  # Allow more time for transcription
def transcribe_audio_or_video(file_input):
    """Transcribe an audio or video file with the shared Whisper model.

    Video inputs are first converted to MP3 with convert_video_to_audio. The
    audio is then passed through preprocess_audio; if that step fails the
    un-preprocessed audio is transcribed instead. Temporary files created
    here are removed before returning.

    Args:
        file_input: A file path (``str``), an object exposing the path as
            ``.name`` (e.g. a Gradio file object), or ``None``.

    Returns:
        The transcribed text; ``""`` when ``file_input`` is ``None``; or a
        message starting with ``"Error"`` when processing fails. Failures are
        reported through the return value rather than raised.
    """
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv', '.webm')
    audio_extensions = ('.mp3', '.wav', '.ogg', '.flac', '.m4a')
    temp_files_to_clean = []

    try:
        # Handle the trivial case before touching the model, so an empty
        # input never triggers a Whisper load.
        if file_input is None:
            logger.info("No file input provided for transcription.")
            return ""  # Return empty string for None input

        model_manager.check_whisper_initialized()

        # Determine input type and get the file path.
        if isinstance(file_input, str):  # Input is a path
            input_path = file_input
            logger.info(f"Processing path input: {input_path}")
            if not os.path.exists(input_path):
                logger.error(f"Input file path does not exist: {input_path}")
                raise FileNotFoundError(f"Input file not found: {input_path}")
        elif hasattr(file_input, 'name'):  # Input is a Gradio File object
            input_path = file_input.name  # Gradio usually provides a temp path
            logger.info(f"Processing Gradio file input: {input_path}")
        else:
            logger.error(f"Unsupported input type for transcription: {type(file_input)}")
            raise TypeError("Invalid input type for transcription. Expected file path or Gradio File object.")

        file_extension = os.path.splitext(input_path)[1].lower()

        if file_extension in video_extensions:
            logger.info(f"Detected video file ({file_extension}), converting to audio...")
            audio_file_to_process = convert_video_to_audio(input_path)
            temp_files_to_clean.append(audio_file_to_process)
        elif file_extension in audio_extensions:
            logger.info(f"Detected audio file ({file_extension}).")
            audio_file_to_process = input_path
        else:
            logger.error(f"Unsupported file extension for transcription: {file_extension}")
            raise ValueError(f"Unsupported file type: {file_extension}")

        # Preprocessing is best-effort: fall back to the original/converted audio.
        try:
            audio_file_to_transcribe = preprocess_audio(audio_file_to_process)
            # Only files created here are scheduled for cleanup.
            if audio_file_to_transcribe != audio_file_to_process:
                temp_files_to_clean.append(audio_file_to_transcribe)
        except Exception as preprocess_err:
            logger.warning(f"Audio preprocessing failed: {preprocess_err}. Using original/converted audio.")
            audio_file_to_transcribe = audio_file_to_process

        logger.info(f"Transcribing audio file: {audio_file_to_transcribe}")
        if not os.path.exists(audio_file_to_transcribe):
            raise FileNotFoundError(f"Audio file to transcribe not found: {audio_file_to_transcribe}")

        with torch.inference_mode():  # Ensure inference mode for efficiency
            result = model_manager.whisper_model.transcribe(
                audio_file_to_transcribe,
                fp16=torch.cuda.is_available(),  # Use fp16 only when CUDA is available
            )
        if not result:
            raise RuntimeError("Transcription failed to produce results")

        transcription = result.get("text", "Error: Transcription result empty")
        # Limit transcription length shown in logs.
        log_transcription = (transcription[:100] + '...') if len(transcription) > 100 else transcription
        logger.info(f"Transcription completed: {log_transcription}")

        return transcription

    except FileNotFoundError as e:
        logger.error(f"File not found error during transcription: {e}")
        return f"Error: Input file not found ({e})"
    except ValueError as e:
        logger.error(f"Value error during transcription: {e}")
        return f"Error: Unsupported file type ({e})"
    except TypeError as e:
        logger.error(f"Type error during transcription setup: {e}")
        return f"Error: Invalid input provided ({e})"
    except RuntimeError as e:
        logger.error(f"Runtime error during transcription: {e}")
        logger.error(traceback.format_exc())
        return f"Error during processing: {e}"
    except Exception as e:
        logger.error(f"Unexpected error during transcription: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error processing the file: An unexpected error occurred."

    finally:
        # Clean up all temporary files created during the process.
        for temp_file in temp_files_to_clean:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
                    logger.info(f"Cleaned up temporary file: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not remove temporary file {temp_file}: {str(e)}")
497
 
498
@lru_cache(maxsize=16)
def _read_document_cached(document_path, mtime_ns, size):
    """Extract text from a document; cached per (path, mtime, size).

    ``mtime_ns`` and ``size`` are only part of the cache key, so a file that
    is replaced at the same path is read again. Exceptions propagate and are
    therefore never cached.
    """
    file_extension = os.path.splitext(document_path)[1].lower()

    if file_extension == ".pdf":
        doc = fitz.open(document_path)
        try:
            return "\n".join([page.get_text() for page in doc])
        finally:
            doc.close()  # Close the handle even if text extraction fails

    if file_extension == ".docx":
        doc = docx.Document(document_path)
        return "\n".join([paragraph.text for paragraph in doc.paragraphs])

    if file_extension in (".xlsx", ".xls"):
        # Read all sheets and combine them, one labelled section per sheet.
        xls = pd.ExcelFile(document_path)
        sections = [
            f"--- Sheet: {sheet_name} ---\n{pd.read_excel(xls, sheet_name=sheet_name).to_string()}\n\n"
            for sheet_name in xls.sheet_names
        ]
        return "".join(sections).strip()

    if file_extension == ".csv":
        try:
            df = pd.read_csv(document_path)
        except pd.errors.ParserError:
            logger.warning(f"Could not parse CSV {document_path} with default comma separator, trying semicolon.")
            df = pd.read_csv(document_path, sep=';')
        else:
            # A semicolon-separated file usually parses "successfully" with a
            # comma separator as one wide column, so no ParserError is raised.
            # Detect that shape and re-read with the right separator.
            if df.shape[1] == 1 and ';' in str(df.columns[0]):
                try:
                    df = pd.read_csv(document_path, sep=';')
                except pd.errors.ParserError:
                    pass  # Keep the comma-parsed frame
        return df.to_string()

    logger.warning(f"Unsupported document type: {file_extension}")
    return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."


def read_document(document_path):
    """Read the content of a document (PDF, DOCX, XLSX, CSV).

    Successful reads are cached per file path, modification time and size.
    Failures are not cached, so a transient error is retried on the next call.

    Args:
        document_path: Path of the document on disk.

    Returns:
        The extracted text; a fixed "Unsupported file type..." message for
        other extensions; or a message starting with ``"Error"`` when the
        file is missing or cannot be read. Errors are not raised.
    """
    try:
        logger.info(f"Reading document: {document_path}")
        if not os.path.exists(document_path):
            raise FileNotFoundError(f"Document not found: {document_path}")

        file_stat = os.stat(document_path)
        return _read_document_cached(document_path, file_stat.st_mtime_ns, file_stat.st_size)
    except FileNotFoundError as e:
        logger.error(f"Error reading document: {e}")
        return f"Error: Document file not found at {document_path}"
    except Exception as e:
        logger.error(f"Error reading document {document_path}: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error reading document: {str(e)}"


# Keep the cache-management helpers that the lru_cache-decorated original exposed.
read_document.cache_clear = _read_document_cached.cache_clear
read_document.cache_info = _read_document_cached.cache_info
542
 
543
@lru_cache(maxsize=16)
def _fetch_url_text(url):
    """Fetch a page and return its main text; cached per URL.

    Network and parsing exceptions propagate to the caller and are therefore
    never cached, so a transient failure is retried on the next call.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    response = requests.get(url, headers=headers, timeout=20, allow_redirects=True)
    response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

    # Proceed only if the response is likely HTML/text.
    content_type = response.headers.get('content-type', '').lower()
    if not ('html' in content_type or 'text' in content_type):
        logger.warning(f"URL {url} has non-text content type: {content_type}. Skipping.")
        return f"Error: URL content type ({content_type}) is not text/html."

    soup = BeautifulSoup(response.content, 'html.parser')

    # Remove non-content elements like scripts, styles, nav, footers etc.
    for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav", "aside", "form", "button"]):
        element.extract()

    # Prefer a dedicated main-content container, then <body>, then the whole page.
    container = (
        soup.find("main") or
        soup.find("article") or
        soup.find("div", class_=["content", "main", "post-content", "entry-content", "article-body"]) or
        soup.find("div", id=["content", "main", "article"]) or
        soup.find("body") or
        soup
    )
    text = container.get_text(separator='\n', strip=True)

    # Drop blank lines and collapse runs of whitespace *within* each line.
    # Line breaks are kept: joining the whole text on spaces would flatten
    # every paragraph into one line.
    text = '\n'.join(' '.join(line.split()) for line in text.split('\n') if line.strip())

    if not text:
        logger.warning(f"Could not extract meaningful text from URL: {url}")
        return "Error: Could not extract text content from URL."

    # Limit content size to avoid overwhelming the LLM.
    max_chars = 15000
    if len(text) > max_chars:
        logger.info(f"URL content truncated to {max_chars} characters.")
        text = text[:max_chars] + "... [content truncated]"

    return text


def read_url(url):
    """Read the main textual content of a URL.

    Args:
        url: Address to fetch. Anything empty or not starting with ``http``
            (after stripping whitespace) is treated as "no URL".

    Returns:
        ``""`` for an empty/invalid URL; the extracted text (capped at 15000
        characters plus a truncation marker); or a message starting with
        ``"Error"`` when the page cannot be fetched, is not text/HTML, or
        yields no text. Fetch/parse failures are returned, not raised.
    """
    if not url or not url.strip().startswith('http'):
        logger.info(f"Invalid or empty URL provided: '{url}'")
        return ""  # Return empty for invalid or empty URLs

    try:
        logger.info(f"Reading URL: {url}")
        return _fetch_url_text(url)
    except requests.exceptions.RequestException as e:
        logger.error(f"Error fetching URL {url}: {str(e)}")
        return f"Error reading URL: Could not fetch content ({e})"
    except Exception as e:
        logger.error(f"Error parsing URL {url}: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error reading URL: Could not parse content ({e})"


# Keep the cache-management helpers that the lru_cache-decorated original exposed.
read_url.cache_clear = _fetch_url_text.cache_clear
read_url.cache_info = _fetch_url_text.cache_info
611
 
612
def process_social_media_url(url):
    """Collect page text and an audio/video transcription for a social media URL.

    Two independent attempts are made: read_url for the page text, and
    download_social_media_video followed by transcribe_audio_or_video for
    any media. The downloaded audio file is deleted afterwards.

    Args:
        url: Address of the social media post.

    Returns:
        ``None`` if the URL is empty/invalid, or if an exception occurred and
        nothing was retrieved. Otherwise a dict
        ``{"text": str, "video": str}`` where each value is ``""`` when that
        kind of content was unavailable.
    """
    if not url or not url.strip().startswith('http'):
        logger.info(f"Invalid or empty social media URL: '{url}'")
        return None

    logger.info(f"Processing social media URL: {url}")
    text_content = None
    video_transcription = None
    error_occurred = False

    # 1. Try extracting text content using read_url (might work for some platforms/posts).
    try:
        text_content = read_url(url)
        # read_url reports failures as "Error: ..." *and* "Error reading URL: ...",
        # so match the bare "Error" prefix; otherwise fetch failures would be
        # kept as if they were page text.
        if text_content and text_content.startswith("Error"):
            logger.warning(f"Failed to read text content from social URL {url}: {text_content}")
            text_content = None  # Reset if it was an error message
    except Exception as e:
        logger.error(f"Error reading text content from social URL {url}: {e}")
        error_occurred = True

    # 2. Try downloading and transcribing potential video/audio content.
    downloaded_audio_path = None
    try:
        downloaded_audio_path = download_social_media_video(url)
        if downloaded_audio_path:
            logger.info(f"Audio downloaded from {url}, proceeding to transcription.")
            video_transcription = transcribe_audio_or_video(downloaded_audio_path)
            if video_transcription and video_transcription.startswith("Error"):
                logger.warning(f"Transcription failed for audio from {url}: {video_transcription}")
                video_transcription = None  # Reset if it was an error
        else:
            logger.info(f"No downloadable audio/video found or download failed for URL: {url}")
    except Exception as e:
        logger.error(f"Error processing video content from social URL {url}: {e}")
        logger.error(traceback.format_exc())
        error_occurred = True
    finally:
        # Clean up the downloaded file if it exists.
        if downloaded_audio_path and os.path.exists(downloaded_audio_path):
            try:
                os.remove(downloaded_audio_path)
                logger.info(f"Cleaned up downloaded audio: {downloaded_audio_path}")
            except Exception as e:
                logger.warning(f"Failed to cleanup downloaded audio {downloaded_audio_path}: {e}")

    # Fail only when an exception occurred and nothing at all was retrieved.
    if error_occurred and not (text_content or video_transcription):
        logger.error(f"Failed to process social media URL {url} completely.")
        return None  # Indicate failure

    return {
        "text": text_content or "",  # Ensure string type
        "video": video_transcription or "",  # Ensure string type
    }
667
+
668
 
669
+ @spaces.GPU(duration=300) # Allow more time for generation
670
  def generate_news(instructions, facts, size, tone, *args):
671
+ """Generate a news article based on provided data using an LLM."""
672
+ request_start_time = time.time()
673
+ logger.info("Received request to generate news.")
674
  try:
675
+ # Ensure size is integer
676
+ try:
677
+ size = int(size) if size else 250 # Default size if None/empty
678
+ except ValueError:
679
+ logger.warning(f"Invalid size value '{size}', defaulting to 250.")
680
  size = 250
681
+
682
+ # Check if models are initialized, load if necessary
683
+ model_manager.check_llm_initialized() # LLM is essential
684
+ # Whisper might be needed later, check/load if audio sources exist
685
+
686
+ # --- Argument Parsing ---
687
+ # The order *must* match the order components are added to inputs_list in create_demo
688
+ # Fixed inputs: instructions, facts, size, tone (already passed directly)
689
+ # Dynamic inputs from *args:
690
+ # Expected order in *args based on create_demo:
691
+ # 5 Documents, 15 Audio-related, 5 URLs, 9 Social-related
692
+ num_docs = 5
693
+ num_audio_sources = 5
694
+ num_audio_inputs_per_source = 3
695
+ num_urls = 5
696
+ num_social_sources = 3
697
+ num_social_inputs_per_source = 3
698
+
699
+ total_expected_args = num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls + (num_social_sources * num_social_inputs_per_source)
700
+
701
+ args_list = list(args)
702
+ # Pad args_list with None if fewer arguments were received than expected
703
+ args_list.extend([None] * (total_expected_args - len(args_list)))
704
+
705
+ # Slice arguments based on the expected order
706
+ doc_files = args_list[0:num_docs]
707
+ audio_inputs_flat = args_list[num_docs : num_docs + (num_audio_sources * num_audio_inputs_per_source)]
708
+ url_inputs = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) : num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls]
709
+ social_inputs_flat = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls : total_expected_args]
710
+
711
  knowledge_base = {
712
+ "instructions": instructions or "No specific instructions provided.",
713
+ "facts": facts or "No specific facts provided.",
714
  "document_content": [],
715
  "audio_data": [],
716
  "url_content": [],
717
  "social_content": []
718
  }
719
+ raw_transcriptions = "" # Initialize transcription log
720
 
721
+ # --- Process Inputs ---
722
+ logger.info("Processing document inputs...")
723
+ for i, doc_file in enumerate(doc_files):
724
+ if doc_file and hasattr(doc_file, 'name'):
725
+ try:
726
+ content = read_document(doc_file.name) # doc_file.name is the temp path
727
+ if content and not content.startswith("Error"):
728
+ # Truncate long documents for the knowledge base summary
729
+ doc_excerpt = (content[:1000] + "... [document truncated]") if len(content) > 1000 else content
730
+ knowledge_base["document_content"].append(f"[Document {i+1} Source: {os.path.basename(doc_file.name)}]\n{doc_excerpt}")
731
+ else:
732
+ logger.warning(f"Skipping document {i+1} due to read error or empty content: {content}")
733
+ except Exception as e:
734
+ logger.error(f"Failed to process document {i+1} ({doc_file.name}): {e}")
735
+ # No cleanup needed here, Gradio handles temp file uploads
736
+
737
+ logger.info("Processing URL inputs...")
738
+ for i, url in enumerate(url_inputs):
739
+ if url and isinstance(url, str) and url.strip().startswith('http'):
740
+ try:
741
+ content = read_url(url)
742
+ if content and not content.startswith("Error"):
743
+ # Content is already truncated in read_url if needed
744
+ knowledge_base["url_content"].append(f"[URL {i+1} Source: {url}]\n{content}")
745
+ else:
746
+ logger.warning(f"Skipping URL {i+1} ({url}) due to read error or empty content: {content}")
747
+ except Exception as e:
748
+ logger.error(f"Failed to process URL {i+1} ({url}): {e}")
749
 
750
+ logger.info("Processing audio/video inputs...")
751
+ has_audio_source = False
752
+ for i in range(num_audio_sources):
753
+ start_idx = i * num_audio_inputs_per_source
754
+ audio_file = audio_inputs_flat[start_idx]
755
+ name = audio_inputs_flat[start_idx + 1] or f"Source {i+1}"
756
+ position = audio_inputs_flat[start_idx + 2] or "N/A"
757
+
758
+ if audio_file and hasattr(audio_file, 'name'):
759
+ # Store info for transcription later
760
+ knowledge_base["audio_data"].append({
761
+ "file_path": audio_file.name, # Use the temp path
762
+ "name": name,
763
+ "position": position,
764
+ "original_filename": os.path.basename(audio_file.name) # Keep original for logs
765
+ })
766
+ has_audio_source = True
767
+ logger.info(f"Added audio source {i+1}: {name} ({position}) - File: {knowledge_base['audio_data'][-1]['original_filename']}")
768
+
769
+ logger.info("Processing social media inputs...")
770
+ has_social_source = False
771
+ for i in range(num_social_sources):
772
+ start_idx = i * num_social_inputs_per_source
773
+ social_url = social_inputs_flat[start_idx]
774
+ social_name = social_inputs_flat[start_idx + 1] or f"Social Source {i+1}"
775
+ social_context = social_inputs_flat[start_idx + 2] or "N/A"
776
+
777
+ if social_url and isinstance(social_url, str) and social_url.strip().startswith('http'):
778
+ try:
779
+ logger.info(f"Processing social media URL {i+1}: {social_url}")
780
+ social_data = process_social_media_url(social_url)
781
+ if social_data:
782
+ knowledge_base["social_content"].append({
 
 
 
 
 
 
 
 
 
 
783
  "url": social_url,
784
+ "name": social_name,
785
+ "context": social_context,
786
+ "text": social_data.get("text", ""),
787
+ "video_transcription": social_data.get("video", "") # Store potential transcription
788
  })
789
+ has_social_source = True
790
+ logger.info(f"Added social source {i+1}: {social_name} ({social_context}) from {social_url}")
791
+ else:
792
+ logger.warning(f"Could not retrieve any content for social URL {i+1}: {social_url}")
793
+ except Exception as e:
794
+ logger.error(f"Failed to process social URL {i+1} ({social_url}): {e}")
795
 
 
 
796
 
797
+ # --- Transcribe Audio/Video ---
798
+ # Only initialize Whisper if needed
799
+ transcriptions_for_prompt = ""
800
+ if has_audio_source or any(sc.get("video_transcription") == "[NEEDS_TRANSCRIPTION]" for sc in knowledge_base["social_content"]): # Check if transcription actually needed
801
+ logger.info("Audio sources detected, ensuring Whisper model is ready...")
802
+ try:
803
+ model_manager.check_whisper_initialized()
804
+ except Exception as whisper_init_err:
805
+ logger.error(f"FATAL: Whisper model initialization failed: {whisper_init_err}. Cannot transcribe.")
806
+ # Add error message to raw transcriptions and continue without transcriptions
807
+ raw_transcriptions += f"[ERROR] Whisper model failed to load. Audio sources could not be transcribed: {whisper_init_err}\n\n"
808
+ # Optionally return an error message immediately?
809
+ # return f"Error: Could not initialize transcription model. {whisper_init_err}", raw_transcriptions
810
+
811
+ if model_manager.whisper_model: # Proceed only if whisper loaded successfully
812
+ logger.info("Transcribing collected audio sources...")
813
+ for idx, data in enumerate(knowledge_base["audio_data"]):
814
+ try:
815
+ logger.info(f"Transcribing audio source {idx+1}: {data['original_filename']} ({data['name']}, {data['position']})")
816
+ transcription = transcribe_audio_or_video(data["file_path"])
817
+ if transcription and not transcription.startswith("Error"):
818
+ quote = f'"{transcription}" - {data["name"]}, {data["position"]}'
819
+ transcriptions_for_prompt += f"{quote}\n\n"
820
+ raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n"{transcription}"\n\n'
821
+ else:
822
+ logger.warning(f"Transcription failed or returned error for audio source {idx+1}: {transcription}")
823
+ raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {transcription}]\n\n'
824
+ except Exception as e:
825
+ logger.error(f"Error during transcription for audio source {idx+1} ({data['original_filename']}): {e}")
826
+ logger.error(traceback.format_exc())
827
+ raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {e}]\n\n'
828
+ # Gradio handles cleanup of the uploaded temp file audio_file.name
829
 
830
+ logger.info("Adding social media content to prompt data...")
831
  for idx, data in enumerate(knowledge_base["social_content"]):
832
+ source_id = f'[Social Media {idx+1}: {data["url"]} ({data["name"]}, {data["context"]})]'
833
+ has_content = False
834
+ if data["text"] and not data["text"].startswith("Error"):
835
+ # Truncate long text for the prompt, but keep full in knowledge base maybe?
836
+ text_excerpt = (data["text"][:500] + "...[text truncated]") if len(data["text"]) > 500 else data["text"]
837
+ social_text_prompt = f'{source_id} - Text Content:\n"{text_excerpt}"\n\n'
838
+ transcriptions_for_prompt += social_text_prompt # Add text content as if it were a quote/source
839
+ raw_transcriptions += f"{source_id}\nText Content:\n{data['text']}\n\n" # Log full text
840
+ has_content = True
841
+ if data["video_transcription"] and not data["video_transcription"].startswith("Error"):
842
+ social_video_prompt = f'{source_id} - Video Transcription:\n"{data["video_transcription"]}"\n\n'
843
+ transcriptions_for_prompt += social_video_prompt
844
+ raw_transcriptions += f"{source_id}\nVideo Transcription:\n{data['video_transcription']}\n\n"
845
+ has_content = True
846
+ if not has_content:
847
+ raw_transcriptions += f"{source_id}\n[No usable text or video transcription found]\n\n"
848
+
849
+
850
+ # --- Prepare Final Prompt ---
851
+ # Combine document and URL summaries
852
+ document_summary = "\n\n".join(knowledge_base["document_content"]) if knowledge_base["document_content"] else "No document content provided."
853
+ url_summary = "\n\n".join(knowledge_base["url_content"]) if knowledge_base["url_content"] else "No URL content provided."
854
+ transcription_summary = transcriptions_for_prompt if transcriptions_for_prompt else "No usable transcriptions available."
 
 
 
 
 
 
 
855
 
856
+ # Construct the prompt for the LLM
857
+ prompt = f"""<s>[INST] You are a professional news writer. Your task is to synthesize information from various sources into a coherent news article.
858
 
859
+ Primary Instructions: {knowledge_base["instructions"]}
860
+ Key Facts to Include: {knowledge_base["facts"]}
861
 
862
+ Supporting Information:
863
 
864
+ Document Content Summary:
865
+ {document_summary}
866
 
867
+ Web Content Summary (from URLs):
868
+ {url_summary}
869
 
870
+ Transcribed Quotes/Content (Use these directly or indirectly):
871
+ {transcription_summary}
872
 
873
+ Article Requirements:
874
+ - Title: Create a concise and informative title for the article.
875
+ - Hook: Write a compelling 15-word (approx.) hook sentence that complements the title.
876
+ - Body: Write the main news article body, aiming for approximately {size} words.
877
+ - Tone: Adopt a {tone} tone throughout the article.
878
+ - 5 Ws: Ensure the first paragraph addresses the core questions (Who, What, When, Where, Why).
879
+ - Quotes: Incorporate relevant information from the 'Transcribed Quotes/Content' section. Aim to use quotes where appropriate, but synthesize information rather than just listing quotes. Use quotation marks (" ") for direct quotes attributed correctly (e.g., based on name/position provided).
880
+ - Style: Adhere to a professional journalistic style. Be objective and factual.
881
+ - Accuracy: Do NOT invent information. Stick strictly to the provided facts, instructions, and source materials. If information is contradictory or missing, state that or omit the detail.
882
+ - Structure: Organize the article logically with clear paragraphs.
883
+
884
+ Begin the article now. [/INST]
885
+ Article Draft:
886
+ """
887
+
888
+ # Log the prompt length (useful for debugging context limits)
889
+ logger.info(f"Generated prompt length: {len(prompt.split())} words / {len(prompt)} characters.")
890
+ # Avoid logging the full prompt if it's too long or contains sensitive info
891
+ # logger.debug(f"Generated Prompt:\n{prompt}")
892
+
893
+ # --- Generate News Article ---
894
+ logger.info("Generating news article with LLM...")
895
+ generation_start_time = time.time()
896
+
897
+ # Estimate max_new_tokens based on requested size + buffer
898
+ # Add buffer for title, hook, and potential verbosity
899
+ estimated_tokens_per_word = 1.5
900
+ max_new_tokens = int(size * estimated_tokens_per_word + 150) # size words + buffer
901
+ # Ensure max_new_tokens doesn't exceed model limits (adjust based on model's max context)
902
+ model_max_length = 2048 # Typical for TinyLlama, but check specific model card
903
+ # Calculate available space for generation
904
+ # Note: This token count is approximate. Precise tokenization is needed for accuracy.
905
+ # prompt_tokens = len(model_manager.tokenizer.encode(prompt)) # More accurate but slower
906
+ prompt_tokens_estimate = len(prompt) // 3 # Rough estimate
907
+ max_new_tokens = min(max_new_tokens, model_max_length - prompt_tokens_estimate - 50) # Leave buffer
908
+ max_new_tokens = max(max_new_tokens, 100) # Ensure at least a minimum generation length
909
+
910
+ logger.info(f"Requesting max_new_tokens: {max_new_tokens}")
911
 
912
  try:
913
+ # Generate using the pipeline
914
+ outputs = model_manager.text_pipeline(
 
 
 
915
  prompt,
916
+ max_new_tokens=max_new_tokens, # Use max_new_tokens instead of max_length
917
+ do_sample=True,
918
+ temperature=0.7, # Standard temperature for creative but factual
919
+ top_p=0.95,
920
+ top_k=50, # Consider adding top_k
921
+ repetition_penalty=1.15, # Adjusted penalty
922
+ pad_token_id=model_manager.tokenizer.eos_token_id,
923
+ num_return_sequences=1
 
 
 
 
924
  )
925
+
926
+ # Extract generated text
927
+ generated_text = outputs[0]['generated_text']
928
+
929
+ # Clean up the result by removing the prompt
930
+ # Find the end of the prompt marker [/INST] and take text after it
931
+ inst_marker = "[/INST]"
932
+ marker_pos = generated_text.find(inst_marker)
933
+ if marker_pos != -1:
934
+ news_article = generated_text[marker_pos + len(inst_marker):].strip()
935
+ # Further clean potentially leading "Article Draft:" if model included it
936
+ if news_article.startswith("Article Draft:"):
937
+ news_article = news_article[len("Article Draft:"):].strip()
938
  else:
939
+ # Fallback: Try removing the input prompt string itself (less reliable)
940
+ if prompt in generated_text:
941
+ news_article = generated_text.replace(prompt, "", 1).strip()
942
+ else:
943
+ # If prompt not found exactly, assume the output is only the generation
944
+ # This might happen if the pipeline handles prompt removal internally sometimes
945
+ news_article = generated_text
946
+ logger.warning("Prompt marker '[/INST]' not found in LLM output. Returning full output.")
947
+
948
+
949
+ generation_time = time.time() - generation_start_time
950
+ logger.info(f"News generation completed in {generation_time:.2f} seconds. Output length: {len(news_article)} characters.")
951
+
952
+ except torch.cuda.OutOfMemoryError as oom_error:
953
+ logger.error(f"CUDA Out of Memory error during LLM generation: {oom_error}")
954
+ logger.error(traceback.format_exc())
955
+ model_manager.reset_models(force=True) # Attempt to recover
956
+ raise RuntimeError("Generation failed due to insufficient GPU memory. Please try reducing article size or complexity.") from oom_error
957
  except Exception as gen_error:
958
+ logger.error(f"Error during text generation pipeline: {str(gen_error)}")
959
+ logger.error(traceback.format_exc())
960
+ raise RuntimeError(f"LLM generation failed: {gen_error}") from gen_error
961
+
962
+ total_time = time.time() - request_start_time
963
+ logger.info(f"Total request processing time: {total_time:.2f} seconds.")
964
+
965
+ # Return the generated article and the log of raw transcriptions
966
+ return news_article, raw_transcriptions.strip()
967
 
968
  except Exception as e:
969
+ total_time = time.time() - request_start_time
970
+ logger.error(f"Error in generate_news function after {total_time:.2f} seconds: {str(e)}")
971
+ logger.error(traceback.format_exc())
972
+ # Attempt to reset models to recover state if possible
973
  try:
974
  model_manager.reset_models(force=True)
975
  except Exception as reset_error:
976
+ logger.error(f"Failed to reset models after error: {str(reset_error)}")
977
+ # Return error messages to the UI
978
+ error_message = f"Error generating the news article: {str(e)}"
979
+ transcription_log = raw_transcriptions.strip() + f"\n\n[ERROR] News generation failed: {str(e)}"
980
+ return error_message, transcription_log
981
 
982
def create_demo():
    """Build the NewsIA Gradio interface and wire up its event handlers.

    The input components are collected in ``all_inputs`` in a fixed order and
    passed positionally to ``generate_news``:

    1. instructions, facts, size (slider), tone (dropdown)
    2. five document files
    3. five (audio/video file, speaker name, speaker role) triples
    4. five URLs
    5. three (post URL, account name, context) triples

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    # Single source of truth for the defaults, shared by the component
    # constructors and by the "Clear" handler so the two cannot drift apart.
    default_size = 250
    default_tone = "neutral"

    # NOTE(review): several emoji in the labels below look mis-encoded in the
    # reviewed copy of this file; they are kept as found - confirm against the
    # original source before shipping.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# πŸ“° NewsIA - AI News Generator")
        gr.Markdown("Create professional news articles from multiple information sources.")

        # Every input component, in the positional order generate_news receives them.
        all_inputs = []

        with gr.Row():
            with gr.Column(scale=2):
                instructions = gr.Textbox(
                    label="Instructions for the News Article",
                    placeholder="Enter specific instructions for generating your news article (e.g., focus on the economic impact)",
                    lines=2,
                    value=""
                )
                all_inputs.append(instructions)

                facts = gr.Textbox(
                    label="Main Facts",
                    placeholder="Describe the most important facts the news should include (e.g., Event name, date, location, key people involved)",
                    lines=4,
                    value=""
                )
                all_inputs.append(facts)

                with gr.Row():
                    size_slider = gr.Slider(
                        label="Approximate Length (words)",
                        minimum=100,
                        maximum=700,
                        value=default_size,
                        step=50
                    )
                    all_inputs.append(size_slider)

                    tone_dropdown = gr.Dropdown(
                        label="Tone of the News Article",
                        choices=["neutral", "serious", "formal", "urgent", "investigative", "human-interest", "lighthearted"],
                        value=default_tone
                    )
                    all_inputs.append(tone_dropdown)

            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("πŸ“ Documents"):
                        gr.Markdown("Upload relevant documents (PDF, DOCX, XLSX, CSV). Max 5.")
                        doc_inputs = []
                        for i in range(1, 6):
                            doc_file = gr.File(
                                label=f"Document {i}",
                                # Extensions need the leading dot; a bare "pdf" is
                                # neither an extension nor a known file category.
                                file_types=[".pdf", ".docx", ".xlsx", ".csv"],
                                file_count="single"
                            )
                            doc_inputs.append(doc_file)
                        all_inputs.extend(doc_inputs)

                    with gr.TabItem("πŸ”Š Audio/Video"):
                        gr.Markdown("Upload audio or video files for transcription (MP3, WAV, MP4, MOV, etc.). Max 5 sources.")
                        audio_video_inputs = []
                        for i in range(1, 6):
                            with gr.Group():
                                gr.Markdown(f"**Source {i}**")
                                audio_file = gr.File(
                                    label=f"Audio/Video File {i}",
                                    file_types=["audio", "video"]
                                )
                                with gr.Row():
                                    speaker_name = gr.Textbox(
                                        label="Speaker Name",
                                        placeholder="Name of the interviewee or speaker",
                                        value=""
                                    )
                                    speaker_role = gr.Textbox(
                                        label="Role/Position",
                                        placeholder="Speaker's title or role",
                                        value=""
                                    )
                            # Each source contributes a (file, name, role) triple, in that order.
                            audio_video_inputs.append(audio_file)
                            audio_video_inputs.append(speaker_name)
                            audio_video_inputs.append(speaker_role)
                        all_inputs.extend(audio_video_inputs)

                    with gr.TabItem("🌐 URLs"):
                        gr.Markdown("Add URLs to relevant web pages or articles. Max 5.")
                        url_inputs = []
                        for i in range(1, 6):
                            url_textbox = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://example.com/article",
                                value=""
                            )
                            url_inputs.append(url_textbox)
                        all_inputs.extend(url_inputs)

                    with gr.TabItem("πŸ“± Social Media"):
                        gr.Markdown("Add URLs to social media posts (e.g., Twitter, YouTube, TikTok). Max 3.")
                        social_inputs = []
                        for i in range(1, 4):
                            with gr.Group():
                                gr.Markdown(f"**Social Media Source {i}**")
                                social_url_textbox = gr.Textbox(
                                    label="Post URL",
                                    placeholder="https://twitter.com/user/status/...",
                                    value=""
                                )
                                with gr.Row():
                                    social_name_textbox = gr.Textbox(
                                        label="Account Name/User",
                                        placeholder="Name or handle (e.g., @username)",
                                        value=""
                                    )
                                    social_context_textbox = gr.Textbox(
                                        label="Context",
                                        placeholder="Brief context (e.g., statement on event X)",
                                        value=""
                                    )
                            # Each source contributes a (url, name, context) triple, in that order.
                            social_inputs.append(social_url_textbox)
                            social_inputs.append(social_name_textbox)
                            social_inputs.append(social_context_textbox)
                        all_inputs.extend(social_inputs)

        with gr.Row():
            generate_button = gr.Button("✨ Generate News Article", variant="primary")
            clear_button = gr.Button("πŸ”„ Clear All Inputs")

        with gr.Tabs():
            with gr.TabItem("πŸ“„ Generated News Article"):
                news_output = gr.Textbox(
                    label="Draft News Article",
                    lines=20,
                    show_copy_button=True,
                    value=""
                )
            with gr.TabItem("πŸŽ™οΈ Source Transcriptions & Logs"):
                transcriptions_output = gr.Textbox(
                    label="Transcriptions and Processing Log",
                    lines=15,
                    show_copy_button=True,
                    value=""
                )

        # --- Event Handlers ---
        outputs_list = [news_output, transcriptions_output]

        generate_button.click(
            fn=generate_news,
            inputs=all_inputs,
            outputs=outputs_list
        )

        def _default_for(component):
            """Return the value a component should hold after "Clear"."""
            if isinstance(component, gr.Slider):
                return default_size
            if isinstance(component, gr.Dropdown):
                # Must be one of the dropdown's choices; "" is not a valid tone.
                return default_tone
            if isinstance(component, gr.Textbox):
                return ""
            # gr.File and any other component type are cleared to None.
            return None

        def clear_all_inputs_and_outputs():
            """Return reset values for every input followed by the two outputs."""
            reset_values = [_default_for(component) for component in all_inputs]
            reset_values.extend(["", ""])  # The two output textboxes.

            # Releasing the models is best-effort: a failure here must not
            # prevent the UI itself from being cleared.
            try:
                model_manager.reset_models(force=True)
                logger.info("UI cleared and models reset.")
            except Exception as reset_error:
                logger.error(f"Failed to reset models while clearing the UI: {str(reset_error)}")

            return reset_values

        clear_button.click(
            fn=clear_all_inputs_and_outputs,
            inputs=None,
            outputs=all_inputs + outputs_list
        )

    return demo
1172
 
1173
if __name__ == "__main__":
    logger.info("Starting NewsIA application...")

    # Whisper is intentionally NOT pre-initialized here: calling
    # model_manager.initialize_whisper() at startup would speed up the first
    # transcription but claim GPU resources immediately. It is loaded on
    # demand instead.

    news_demo = create_demo()

    # Default queue settings (no concurrency_count / max_size overrides).
    news_demo.queue()

    launch_options = {
        "server_name": "0.0.0.0",  # Listen on all interfaces (Docker/Spaces).
        "server_port": 7860,
    }

    logger.info("Launching Gradio interface...")
    news_demo.launch(**launch_options)
    logger.info("NewsIA application finished.")