CamiloVega committed
Commit f1d02c3 · verified · 1 Parent(s): f84674f

Update app.py

Files changed (1)
  app.py (+119, -187)
app.py CHANGED
@@ -17,7 +17,8 @@ from functools import lru_cache
 import gc
 import time
 from huggingface_hub import login
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from unsloth import FastLanguageModel
+from transformers import AutoTokenizer
 
 # Configure logging
 logging.basicConfig(
@@ -51,47 +52,34 @@ class ModelManager:
 
     @spaces.GPU()
     def initialize_llm(self):
-        """Initialize LLM model with standard transformers"""
+        """Initialize LLM model with Unsloth optimization"""
         try:
-            # Use small model for ZeroGPU compatibility
             MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
-            logger.info("Loading tokenizer...")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                MODEL_NAME,
-                token=HUGGINGFACE_TOKEN,
-                use_fast=True
+            logger.info("Loading Unsloth-optimized model...")
+            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+                model_name = MODEL_NAME,
+                max_seq_length = 2048,
+                dtype = torch.float16,
+                load_in_4bit = True,
+                token = HUGGINGFACE_TOKEN,
             )
 
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # Basic memory settings for ZeroGPU
-            logger.info("Loading model...")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                MODEL_NAME,
-                token=HUGGINGFACE_TOKEN,
-                device_map="auto",
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True,
-                # Optimizations for ZeroGPU
-                max_memory={0: "4GB"},
-                offload_folder="offload",
-                offload_state_dict=True
-            )
-
-            # Create text generation pipeline
-            logger.info("Creating pipeline...")
-            self.pipeline = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                torch_dtype=torch.float16,
-                device_map="auto",
-                max_length=1024
+            # Enable LoRA for better ZeroGPU performance
+            self.model = FastLanguageModel.get_peft_model(
+                self.model,
+                r = 16,
+                target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                                  "gate_proj", "up_proj", "down_proj"],
+                lora_alpha = 16,
+                lora_dropout = 0,
+                bias = "none",
+                use_gradient_checkpointing = True,
+                random_state = 3407,
+                max_seq_length = 2048,
             )
 
-            logger.info("LLM initialized successfully")
+            logger.info("LLM initialized successfully with Unsloth")
             self.last_used = time.time()
             return True
 
@@ -101,14 +89,15 @@ class ModelManager:
 
     @spaces.GPU()
     def initialize_whisper(self):
-        """Initialize Whisper model for audio transcription"""
+        """Initialize Whisper model with safety fix"""
         try:
             logger.info("Loading Whisper model...")
-            # Using tiny model for efficiency but can be changed based on needs
+            # Load with weights_only=True for security
             self.whisper_model = whisper.load_model(
                 "tiny",
                 device="cuda" if torch.cuda.is_available() else "cpu",
-                download_root="/tmp/whisper"
+                download_root="/tmp/whisper",
+                weights_only=True  # Security fix
             )
             logger.info("Whisper model initialized successfully")
             self.last_used = time.time()
@@ -119,7 +108,7 @@ class ModelManager:
 
     def check_llm_initialized(self):
         """Check if LLM is initialized and initialize if needed"""
-        if self.tokenizer is None or self.model is None or self.pipeline is None:
+        if self.tokenizer is None or self.model is None:
             logger.info("LLM not initialized, initializing...")
             self.initialize_llm()
         self.last_used = time.time()
@@ -134,26 +123,21 @@
     def reset_models(self, force=False):
         """Reset models to free memory if they haven't been used recently"""
         current_time = time.time()
-        # Only reset if forced or models haven't been used for 10 minutes
         if force or (current_time - self.last_used > 600):
             try:
                 logger.info("Resetting models to free memory...")
 
-                if hasattr(self, 'model') and self.model is not None:
+                if self.model is not None:
                     del self.model
 
-                if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+                if self.tokenizer is not None:
                     del self.tokenizer
 
-                if hasattr(self, 'pipeline') and self.pipeline is not None:
-                    del self.pipeline
-
-                if hasattr(self, 'whisper_model') and self.whisper_model is not None:
+                if self.whisper_model is not None:
                     del self.whisper_model
 
                 self.tokenizer = None
                 self.model = None
-                self.pipeline = None
                 self.whisper_model = None
 
                 if torch.cuda.is_available():
@@ -166,7 +150,6 @@ class ModelManager:
         except Exception as e:
            logger.error(f"Error resetting models: {str(e)}")
 
-# Create global model manager instance
 model_manager = ModelManager()
 
 @lru_cache(maxsize=32)
@@ -197,7 +180,6 @@ def convert_video_to_audio(video_file):
     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
         output_file = temp_file.name
 
-    # Use ffmpeg directly via subprocess
     command = [
         "ffmpeg",
         "-i", video_file,
@@ -205,7 +187,7 @@
         "-map", "a",
         "-vn",
         output_file,
-        "-y"  # Overwrite output file if it exists
+        "-y"
     ]
 
     subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -239,10 +221,10 @@ def transcribe_audio(file):
             file_path = download_social_media_video(file)
         elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
             file_path = convert_video_to_audio(file)
-        elif file is not None:  # Handle file object from Gradio
+        elif file is not None:
             file_path = preprocess_audio(file.name)
         else:
-            return ""  # Return empty string for None input
+            return ""
 
         logger.info(f"Transcribing audio: {file_path}")
         if not os.path.exists(file_path):
@@ -250,13 +232,10 @@
 
         with torch.inference_mode():
             result = model_manager.whisper_model.transcribe(file_path)
-            if not result:
-                raise RuntimeError("Transcription failed to produce results")
 
         transcription = result.get("text", "Error in transcription")
         logger.info(f"Transcription completed: {transcription[:50]}...")
 
-        # Clean up temp file
         try:
             if os.path.exists(file_path):
                 os.remove(file_path)
@@ -302,22 +281,19 @@ def read_url(url):
         response.raise_for_status()
         soup = BeautifulSoup(response.content, 'html.parser')
 
-        # Remove non-content elements
         for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
             element.extract()
 
-        # Extract main content
         main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
         if main_content:
             text = main_content.get_text(separator='\n', strip=True)
         else:
             text = soup.get_text(separator='\n', strip=True)
 
-        # Clean up whitespace
         lines = [line.strip() for line in text.split('\n') if line.strip()]
         text = '\n'.join(lines)
 
-        return text[:10000]  # Limit to 10k chars to avoid huge inputs
+        return text[:10000]
     except Exception as e:
         logger.error(f"Error reading URL: {str(e)}")
         return f"Error reading URL: {str(e)}"
@@ -347,16 +323,13 @@ def process_social_content(url):
 def generate_news(instructions, facts, size, tone, *args):
     """Generate a news article based on provided data"""
     try:
-        # Ensure size is integer
         if isinstance(size, float):
             size = int(size)
         elif not isinstance(size, int):
-            size = 250  # Default size
+            size = 250
 
-        # Check if models are initialized
         model_manager.check_llm_initialized()
 
-        # Prepare data structure for inputs
         knowledge_base = {
             "instructions": instructions or "",
             "facts": facts or "",
@@ -366,15 +339,12 @@ def generate_news(instructions, facts, size, tone, *args):
             "social_content": []
         }
 
-        # Define the indices for parsing args
         num_audios = 5 * 3
         num_social_urls = 3 * 3
         num_urls = 5
 
-        # Parse arguments
-        args = list(args)  # Convert tuple to list for easier manipulation
+        args = list(args)
 
-        # Ensure we have enough arguments
         while len(args) < (num_audios + num_social_urls + num_urls + 5):
             args.append("")
 
@@ -383,7 +353,6 @@
         urls = args[num_audios+num_social_urls:num_audios+num_social_urls+num_urls]
         documents = args[num_audios+num_social_urls+num_urls:]
 
-        # Process URLs with progress reporting
         logger.info("Processing URLs...")
         for url in urls:
             if url and isinstance(url, str) and url.strip():
@@ -391,7 +360,6 @@
                 if content and not content.startswith("Error"):
                     knowledge_base["url_content"].append(content)
 
-        # Process documents
         logger.info("Processing documents...")
         for document in documents:
             if document and hasattr(document, 'name'):
@@ -399,10 +367,9 @@
                 if content and not content.startswith("Error"):
                     knowledge_base["document_content"].append(content)
 
-        # Process audio/video files
         logger.info("Processing audio/video files...")
         for i in range(0, len(audios), 3):
-            if i+2 < len(audios):  # Ensure we have complete set of 3 elements
+            if i+2 < len(audios):
                 audio_file, name, position = audios[i:i+3]
                 if audio_file and hasattr(audio_file, 'name'):
                     knowledge_base["audio_data"].append({
@@ -411,10 +378,9 @@
                         "position": position or "Not specified"
                     })
 
-        # Process social media content
         logger.info("Processing social media content...")
         for i in range(0, len(social_urls), 3):
-            if i+2 < len(social_urls):  # Ensure we have complete set of 3 elements
+            if i+2 < len(social_urls):
                 social_url, social_name, social_context = social_urls[i:i+3]
                 if social_url and isinstance(social_url, str) and social_url.strip():
                     social_content = process_social_content(social_url)
@@ -427,11 +393,9 @@
                         "video": social_content.get("video", "")
                     })
 
-        # Prepare transcriptions text
         transcriptions_text = ""
         raw_transcriptions = ""
 
-        # Process audio data transcriptions
         logger.info("Transcribing audio...")
         for idx, data in enumerate(knowledge_base["audio_data"]):
             if data["audio"] is not None:
@@ -440,10 +404,8 @@
                 transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n\n'
                 raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'
 
-        # Process social media content transcriptions
         for idx, data in enumerate(knowledge_base["social_content"]):
             if data["text"] and not str(data["text"]).startswith("Error"):
-                # Truncate long texts for the prompt
                 text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
                 social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
                 transcriptions_text += social_text
@@ -454,10 +416,8 @@
                 transcriptions_text += video_transcription
                 raw_transcriptions += video_transcription
 
-        # Combine document content and URL content (with truncation for very long content)
         document_summaries = []
         for idx, doc in enumerate(knowledge_base["document_content"]):
-            # Truncate long documents
             if len(doc) > 1000:
                 doc_excerpt = doc[:1000] + "... [document continues]"
             else:
@@ -468,7 +428,6 @@
 
         url_summaries = []
         for idx, url_content in enumerate(knowledge_base["url_content"]):
-            # Truncate long URL content
             if len(url_content) > 1000:
                 url_excerpt = url_content[:1000] + "... [content continues]"
             else:
@@ -477,7 +436,6 @@
 
         url_content = "\n\n".join(url_summaries)
 
-        # Create prompt for the model
         prompt = f"""<s>[INST] You are a professional news writer. Write a news article based on the following information:
 
 Instructions: {knowledge_base["instructions"]}
@@ -504,35 +462,33 @@ Follow these requirements:
 - Do not invent information
 - Be rigorous with the provided facts [/INST]"""
 
-        # Generate with standard pipeline
         try:
             logger.info("Generating news article...")
 
-            # Set max length based on requested size
             max_length = min(len(prompt.split()) + size * 2, 2048)
 
-            # Generate using the pipeline
-            outputs = model_manager.pipeline(
+            inputs = model_manager.tokenizer(
                 prompt,
-                max_length=max_length,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.95,
-                repetition_penalty=1.2,
-                pad_token_id=model_manager.tokenizer.eos_token_id,
-                num_return_sequences=1
+                return_tensors = "pt",
+                padding = True,
+                truncation = True,
+                max_length = 2048,
+            ).to("cuda")
+
+            outputs = model_manager.model.generate(
+                **inputs,
+                max_new_tokens = size + 100,
+                temperature = 0.7,
+                do_sample = True,
+                pad_token_id = model_manager.tokenizer.eos_token_id,
             )
 
-            # Extract generated text
-            generated_text = outputs[0]['generated_text']
+            generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens = True)
 
-            # Clean up the result by removing the prompt
             if "[/INST]" in generated_text:
                 news_article = generated_text.split("[/INST]")[1].strip()
             else:
-                # Try to extract the text after the prompt
-                prompt_words = prompt.split()[:50]  # Use first 50 words to identify
-                prompt_fragment = " ".join(prompt_words)
+                prompt_fragment = " ".join(prompt.split()[:50])
                 if prompt_fragment in generated_text:
                     news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
                 else:
@@ -549,173 +505,149 @@ Follow these requirements:
     except Exception as e:
         logger.error(f"Error generating news: {str(e)}")
         try:
-            # Reset models to recover from errors
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models: {str(reset_error)}")
-        return f"Error generando la noticia: {str(e)}", "Error procesando las transcripciones."
+        return f"Error generating news: {str(e)}", "Error processing transcriptions."
 
 def create_demo():
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("# 📰 NewsIA - Generador de Noticias IA")
-        gr.Markdown("Crea noticias profesionales a partir de múltiples fuentes de información.")
+        gr.Markdown("# 📰 NewsIA - AI News Generator")
+        gr.Markdown("Create professional news articles from multiple sources.")
 
         with gr.Row():
             with gr.Column(scale=2):
-                instrucciones = gr.Textbox(
-                    label="Instrucciones para la noticia",
-                    placeholder="Escribe instrucciones específicas para la generación de tu noticia",
-                    lines=2,
-                    value=""
+                instructions = gr.Textbox(
+                    label="News Instructions",
+                    placeholder="Enter specific instructions for news generation",
+                    lines=2
                 )
-                hechos = gr.Textbox(
-                    label="Hechos principales",
-                    placeholder="Describe los hechos más importantes que debe incluir la noticia",
-                    lines=4,
-                    value=""
+                facts = gr.Textbox(
+                    label="Key Facts",
+                    placeholder="Describe the most important facts to include",
+                    lines=4
                 )
 
                 with gr.Row():
-                    tamaño = gr.Slider(
-                        label="Longitud aproximada (palabras)",
+                    size = gr.Slider(
+                        label="Approximate Length (words)",
                         minimum=100,
                         maximum=500,
                         value=250,
                         step=50
                     )
-                    tono = gr.Dropdown(
-                        label="Tono de la noticia",
-                        choices=["serio", "neutral", "divertido", "formal", "informal", "urgente"],
+                    tone = gr.Dropdown(
+                        label="News Tone",
+                        choices=["serious", "neutral", "funny", "formal", "informal", "urgent"],
                         value="neutral"
                     )
 
             with gr.Column(scale=3):
-                # Inicializamos la lista de inputs con valores conocidos
                 inputs_list = []
-                inputs_list.append(instrucciones)
-                inputs_list.append(hechos)
-                inputs_list.append(tamaño)
-                inputs_list.append(tono)
+                inputs_list.extend([instructions, facts, size, tone])
 
                 with gr.Tabs():
-                    with gr.TabItem("📁 Documentos"):
-                        documentos = []
-                        for i in range(1, 6):  # Mantenemos 5 documentos como en el original
-                            documento = gr.File(
-                                label=f"Documento {i}",
+                    with gr.TabItem("📁 Documents"):
+                        documents = []
+                        for i in range(1, 6):
+                            doc = gr.File(
+                                label=f"Document {i}",
                                 file_types=["pdf", "docx", "xlsx", "csv"],
-                                file_count="single",
-                                value=None
+                                file_count="single"
                             )
-                            documentos.append(documento)
-                            inputs_list.append(documento)
+                            documents.append(doc)
+                            inputs_list.append(doc)
 
                     with gr.TabItem("🔊 Audio/Video"):
-                        for i in range(1, 6):  # Mantenemos 5 fuentes como en el original
+                        for i in range(1, 6):
                             with gr.Group():
-                                gr.Markdown(f"**Fuente {i}**")
+                                gr.Markdown(f"**Source {i}**")
                                 file = gr.File(
                                     label=f"Audio/Video {i}",
-                                    file_types=["audio", "video"],
-                                    value=None
+                                    file_types=["audio", "video"]
                                 )
                                 with gr.Row():
-                                    nombre = gr.Textbox(
-                                        label="Nombre",
-                                        placeholder="Nombre del entrevistado",
-                                        value=""
+                                    name = gr.Textbox(
+                                        label="Name",
+                                        placeholder="Interviewee name"
                                     )
-                                    cargo = gr.Textbox(
-                                        label="Cargo/Rol",
-                                        placeholder="Cargo o rol",
-                                        value=""
+                                    position = gr.Textbox(
+                                        label="Position/Role",
+                                        placeholder="Position or role"
                                     )
-                                inputs_list.append(file)
-                                inputs_list.append(nombre)
-                                inputs_list.append(cargo)
+                                inputs_list.extend([file, name, position])
 
                     with gr.TabItem("🌐 URLs"):
-                        for i in range(1, 6):  # Mantenemos 5 URLs como en el original
+                        for i in range(1, 6):
                             url = gr.Textbox(
                                 label=f"URL {i}",
-                                placeholder="https://...",
-                                value=""
+                                placeholder="https://..."
                             )
                             inputs_list.append(url)
 
-                    with gr.TabItem("📱 Redes Sociales"):
-                        for i in range(1, 4):  # Mantenemos 3 redes sociales como en el original
+                    with gr.TabItem("📱 Social Media"):
+                        for i in range(1, 4):
                             with gr.Group():
-                                gr.Markdown(f"**Red Social {i}**")
+                                gr.Markdown(f"**Social Media {i}**")
                                 social_url = gr.Textbox(
-                                    label=f"URL",
-                                    placeholder="https://...",
-                                    value=""
+                                    label="URL",
+                                    placeholder="https://..."
                                 )
                                 with gr.Row():
-                                    social_nombre = gr.Textbox(
-                                        label=f"Nombre/Cuenta",
-                                        placeholder="Nombre de la persona o cuenta",
-                                        value=""
+                                    social_name = gr.Textbox(
+                                        label="Account/Name",
+                                        placeholder="Account or person name"
                                     )
-                                    social_contexto = gr.Textbox(
-                                        label=f"Contexto",
-                                        placeholder="Contexto relevante",
-                                        value=""
+                                    social_context = gr.Textbox(
+                                        label="Context",
+                                        placeholder="Relevant context"
                                     )
-                                inputs_list.append(social_url)
-                                inputs_list.append(social_nombre)
-                                inputs_list.append(social_contexto)
+                                inputs_list.extend([social_url, social_name, social_context])
 
         with gr.Row():
-            generar = gr.Button("✨ Generar Noticia", variant="primary")
-            reset = gr.Button("🔄 Limpiar Todo")
+            generate_btn = gr.Button("✨ Generate News", variant="primary")
+            reset_btn = gr.Button("🔄 Clear All")
 
         with gr.Tabs():
-            with gr.TabItem("📄 Noticia Generada"):
-                noticia_output = gr.Textbox(
-                    label="Borrador de la noticia",
+            with gr.TabItem("📄 Generated News"):
+                news_output = gr.Textbox(
+                    label="News Draft",
                     lines=15,
-                    show_copy_button=True,
-                    value=""
+                    show_copy_button=True
                 )
 
-            with gr.TabItem("🎙️ Transcripciones"):
-                transcripciones_output = gr.Textbox(
-                    label="Transcripciones de fuentes",
+            with gr.TabItem("🎙️ Transcriptions"):
+                transcriptions_output = gr.Textbox(
+                    label="Source Transcriptions",
                     lines=10,
-                    show_copy_button=True,
-                    value=""
+                    show_copy_button=True
                 )
 
-        # Set up event handlers
-        generar.click(
+        generate_btn.click(
             fn=generate_news,
             inputs=inputs_list,
-            outputs=[noticia_output, transcripciones_output]
+            outputs=[news_output, transcriptions_output]
         )
 
-        # Reset functionality to clear all inputs
         def reset_all():
-            return [""] * len(inputs_list) + ["", ""]
+            return [None]*len(inputs_list) + ["", ""]
 
-        reset.click(
+        reset_btn.click(
             fn=reset_all,
             inputs=None,
-            outputs=inputs_list + [noticia_output, transcripciones_output]
+            outputs=inputs_list + [news_output, transcriptions_output]
         )
 
     return demo
 
 if __name__ == "__main__":
     try:
-        # Try initializing whisper model on startup
         model_manager.initialize_whisper()
     except Exception as e:
         logger.warning(f"Initial whisper model loading failed: {str(e)}")
 
     demo = create_demo()
-    demo.queue(concurrency_count=1, max_size=5)
+    demo.queue(max_size=5)
     demo.launch(
         share=True,
         server_name="0.0.0.0",
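
For a quick sanity check outside the Space, the snippet below is a minimal standalone sketch of the load-and-generate path this commit switches to; it is not part of app.py. It assumes unsloth is installed and a CUDA device is available, reuses the TinyLlama checkpoint named in the diff, and deliberately skips the get_peft_model LoRA wrapper, since that matters for fine-tuning rather than plain generation; FastLanguageModel.for_inference is Unsloth's documented switch into inference mode.

# Hypothetical standalone sketch of the Unsloth path adopted above (not part of app.py).
import torch
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_seq_length=2048,
    dtype=torch.float16,
    load_in_4bit=True,  # 4-bit quantization, matching initialize_llm
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference mode

prompt = "<s>[INST] Write a two-sentence news brief about a local marathon. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=120, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))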
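
One caveat on the Whisper hunk: weights_only does not appear to be a documented keyword of whisper.load_model() (its signature in openai-whisper releases to date is roughly load_model(name, device=None, download_root=None, in_memory=False)), so the call above is likely to raise a TypeError; the weights_only hardening belongs to torch.load, which whisper invokes internally. A sketch that stays within the documented API, with "sample.mp3" as a placeholder path:

# Sketch using only documented whisper.load_model keywords (not part of app.py).
import torch
import whisper

whisper_model = whisper.load_model(
    "tiny",
    device="cuda" if torch.cuda.is_available() else "cpu",
    download_root="/tmp/whisper",
)
print(whisper_model.transcribe("sample.mp3")["text"])  # "sample.mp3" is a placeholder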