CamiloVega committed
Commit 3e010de · verified · Parent: 2201bd2

Update app.py

Files changed (1): app.py (+395 -233)
app.py CHANGED
@@ -6,7 +6,6 @@ import tempfile
 import pandas as pd
 import requests
 from bs4 import BeautifulSoup
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import whisper
 from moviepy.editor import VideoFileClip
@@ -16,6 +15,11 @@ import docx
 import yt_dlp
 from functools import lru_cache
 import gc
+import time
+from huggingface_hub import login
+from transformers import AutoTokenizer, BitsAndBytesConfig
+from unsloth import FastLanguageModel
+import tqdm
 
 # Configure logging
 logging.basicConfig(
@@ -24,6 +28,11 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+# Login to Hugging Face Hub if token is available
+HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
+if HUGGINGFACE_TOKEN:
+    login(token=HUGGINGFACE_TOKEN)
+
 class ModelManager:
     _instance = None
 
@@ -37,127 +46,126 @@ class ModelManager:
         if not self._initialized:
             self.tokenizer = None
             self.model = None
-            self.news_generator = None
             self.whisper_model = None
             self._initialized = True
+            self.last_used = time.time()
 
-    @spaces.GPU(duration=120)
-    def initialize_models(self):
-        """Initialize models with ZeroGPU compatible settings"""
+    @spaces.GPU()
+    def initialize_llm(self):
+        """Initialize LLM model with unsloth optimization"""
         try:
-            import torch
-            from transformers import AutoModelForCausalLM, AutoTokenizer
+            MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
 
-            HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
-            if not HUGGINGFACE_TOKEN:
-                raise ValueError("HUGGINGFACE_TOKEN environment variable not set")
-
-            logger.info("Starting model initialization...")
-            model_name = "meta-llama/Llama-2-7b-chat-hf"
-
-            # Load tokenizer
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
+                MODEL_NAME,
                 token=HUGGINGFACE_TOKEN,
                 use_fast=True,
-                model_max_length=512
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # Initialize model with ZeroGPU compatible settings
-            logger.info("Loading model...")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
+
+            # Configure 4-bit quantization
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True
+            )
+
+            logger.info("Loading and optimizing model with unsloth...")
+            # Use unsloth to load and optimize the model
+            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+                model_name=MODEL_NAME,
                 token=HUGGINGFACE_TOKEN,
-                device_map="auto",
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True,
-                use_safetensors=True,
-                # ZeroGPU specific settings
-                max_memory={0: "6GB"},
-                offload_folder="offload",
-                offload_state_dict=True
+                quantization_config=bnb_config,
+                max_seq_length=2048,
+                device_map="auto"
             )
-
-            # Create pipeline with minimal settings
-            logger.info("Creating pipeline...")
-            from transformers import pipeline
-            self.news_generator = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device_map="auto",
-                torch_dtype=torch.float16,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.95,
-                repetition_penalty=1.2,
-                num_return_sequences=1,
-                early_stopping=True
+
+            # Optimize with unsloth
+            self.model = FastLanguageModel.get_peft_model(
+                self.model,
+                r=16,
+                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                                "gate_proj", "up_proj", "down_proj"],
+                lora_alpha=16,
+                lora_dropout=0,
+                bias="none",
+                use_gradient_checkpointing=True,
+                random_state=3407
             )
+
+            logger.info("LLM initialized successfully")
+            self.last_used = time.time()
+            return True
+
+        except Exception as e:
+            logger.error(f"Error initializing LLM: {str(e)}")
+            raise
 
-            # Load Whisper model with minimal settings
+    @spaces.GPU()
+    def initialize_whisper(self):
+        """Initialize Whisper model for audio transcription"""
+        try:
             logger.info("Loading Whisper model...")
+            # Using tiny model for efficiency but can be changed based on needs
             self.whisper_model = whisper.load_model(
                 "tiny",
                 device="cuda" if torch.cuda.is_available() else "cpu",
                 download_root="/tmp/whisper"
             )
-
-            logger.info("All models initialized successfully")
+            logger.info("Whisper model initialized successfully")
+            self.last_used = time.time()
             return True
-
         except Exception as e:
-            logger.error(f"Error during model initialization: {str(e)}")
-            self.reset_models()
+            logger.error(f"Error initializing Whisper: {str(e)}")
             raise
 
-    def reset_models(self):
-        """Reset all models and clear memory"""
-        try:
-            if hasattr(self, 'model') and self.model is not None:
-                self.model.cpu()
-                del self.model
-
-            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
-                del self.tokenizer
+    def check_llm_initialized(self):
+        """Check if LLM is initialized and initialize if needed"""
+        if self.tokenizer is None or self.model is None:
+            logger.info("LLM not initialized, initializing...")
+            self.initialize_llm()
+        self.last_used = time.time()
+
+    def check_whisper_initialized(self):
+        """Check if Whisper model is initialized and initialize if needed"""
+        if self.whisper_model is None:
+            logger.info("Whisper model not initialized, initializing...")
+            self.initialize_whisper()
+        self.last_used = time.time()
+
+    def reset_models(self, force=False):
+        """Reset models to free memory if they haven't been used recently"""
+        current_time = time.time()
+        # Only reset if forced or models haven't been used for 10 minutes
+        if force or (current_time - self.last_used > 600):
+            try:
+                logger.info("Resetting models to free memory...")
 
-            if hasattr(self, 'news_generator') and self.news_generator is not None:
-                del self.news_generator
+                if hasattr(self, 'model') and self.model is not None:
+                    del self.model
+
+                if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+                    del self.tokenizer
+
+                if hasattr(self, 'whisper_model') and self.whisper_model is not None:
+                    del self.whisper_model
 
-            if hasattr(self, 'whisper_model') and self.whisper_model is not None:
-                if hasattr(self.whisper_model, 'cpu'):
-                    self.whisper_model.cpu()
-                del self.whisper_model
-
-            self.tokenizer = None
-            self.model = None
-            self.news_generator = None
-            self.whisper_model = None
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-
-            import gc
-            gc.collect()
-
-        except Exception as e:
-            logger.error(f"Error during model reset: {str(e)}")
-
-    def check_models_initialized(self):
-        """Check if all models are properly initialized"""
-        if None in (self.tokenizer, self.model, self.news_generator, self.whisper_model):
-            logger.warning("Models not initialized, attempting to initialize...")
-            self.initialize_models()
-
-    def get_models(self):
-        """Get initialized models, initializing if necessary"""
-        self.check_models_initialized()
-        return self.tokenizer, self.model, self.news_generator, self.whisper_model
-
+                self.tokenizer = None
+                self.model = None
+                self.whisper_model = None
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    torch.cuda.synchronize()
+
+                gc.collect()
+                logger.info("Models reset successfully")
+
+            except Exception as e:
+                logger.error(f"Error resetting models: {str(e)}")
+
 # Create global model manager instance
 model_manager = ModelManager()
 
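A note on the model-loading hunk above: FastLanguageModel.from_pretrained returns a (model, tokenizer) pair, so the AutoTokenizer loaded a few lines earlier ends up overwritten, and get_peft_model attaches trainable LoRA adapters, which matter for fine-tuning rather than plain generation. For comparison, a minimal sketch of unsloth's documented inference-side loading path; passing an explicit BitsAndBytesConfig through from_pretrained, as the commit does, is assumed to work via kwargs passthrough rather than being part of the documented API:

    # Hedged sketch, not part of the commit: unsloth's documented 4-bit load.
    import torch
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # same base model as the commit
        max_seq_length=2048,
        dtype=torch.float16,
        load_in_4bit=True,  # unsloth builds the nf4 quantization config internally
    )
    FastLanguageModel.for_inference(model)  # enable unsloth's faster inference mode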
@@ -188,7 +196,7 @@ def convert_video_to_audio(video_file):
     try:
         video = VideoFileClip(video_file)
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
-            video.audio.write_audiofile(temp_file.name)
+            video.audio.write_audiofile(temp_file.name, verbose=False, logger=None)
             logger.info(f"Video converted to audio: {temp_file.name}")
             return temp_file.name
     except Exception as e:
@@ -208,30 +216,40 @@ def preprocess_audio(audio_file):
         logger.error(f"Error preprocessing audio: {str(e)}")
         raise
 
-@spaces.GPU(duration=120)
+@spaces.GPU()
 def transcribe_audio(file):
     """Transcribe an audio or video file."""
     try:
-        _, _, _, whisper_model = model_manager.get_models()
+        model_manager.check_whisper_initialized()
 
         if isinstance(file, str) and file.startswith('http'):
            file_path = download_social_media_video(file)
         elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(file)
+        elif file is not None:  # Handle file object from Gradio
+            file_path = preprocess_audio(file.name)
         else:
-            file_path = preprocess_audio(file)
+            return ""  # Return empty string for None input
 
         logger.info(f"Transcribing audio: {file_path}")
         if not os.path.exists(file_path):
             raise FileNotFoundError(f"Audio file not found: {file_path}")
 
         with torch.inference_mode():
-            result = whisper_model.transcribe(file_path)
+            result = model_manager.whisper_model.transcribe(file_path)
             if not result:
                 raise RuntimeError("Transcription failed to produce results")
 
         transcription = result.get("text", "Error in transcription")
         logger.info(f"Transcription completed: {transcription[:50]}...")
+
+        # Clean up temp file
+        try:
+            if os.path.exists(file_path):
+                os.remove(file_path)
+        except Exception as e:
+            logger.warning(f"Could not remove temp file {file_path}: {str(e)}")
+
         return transcription
     except Exception as e:
         logger.error(f"Error transcribing: {str(e)}")
@@ -247,7 +265,7 @@ def read_document(document_path):
     elif document_path.endswith(".docx"):
         doc = docx.Document(document_path)
         return "\n".join([paragraph.text for paragraph in doc.paragraphs])
-    elif document_path.endswith(".xlsx"):
+    elif document_path.endswith((".xlsx", ".xls")):
         return pd.read_excel(document_path).to_string()
     elif document_path.endswith(".csv"):
         return pd.read_csv(document_path).to_string()
@@ -260,17 +278,42 @@ def read_document(document_path):
 @lru_cache(maxsize=32)
 def read_url(url):
     """Read the content of a URL."""
+    if not url or url.strip() == "":
+        return ""
+
     try:
-        response = requests.get(url)
+        headers = {
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+        }
+        response = requests.get(url, headers=headers, timeout=15)
         response.raise_for_status()
         soup = BeautifulSoup(response.content, 'html.parser')
-        return soup.get_text()
+
+        # Remove non-content elements
+        for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
+            element.extract()
+
+        # Extract main content
+        main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
+        if main_content:
+            text = main_content.get_text(separator='\n', strip=True)
+        else:
+            text = soup.get_text(separator='\n', strip=True)
+
+        # Clean up whitespace
+        lines = [line.strip() for line in text.split('\n') if line.strip()]
+        text = '\n'.join(lines)
+
+        return text[:10000]  # Limit to 10k chars to avoid huge inputs
     except Exception as e:
         logger.error(f"Error reading URL: {str(e)}")
         return f"Error reading URL: {str(e)}"
 
 def process_social_content(url):
     """Process social media content."""
+    if not url or url.strip() == "":
+        return None
+
     try:
         text_content = read_url(url)
         try:
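The rewritten read_url now prefers a main or article node, strips script/style noise, and truncates to 10,000 characters. The same heuristic can be exercised offline on an inline HTML string (the sample markup is illustrative):

    # Hedged sketch, not part of the commit: read_url's extraction heuristic.
    from bs4 import BeautifulSoup

    html = """<html><body><nav>Menu</nav>
    <article><h1>Headline</h1><p>Body text.</p></article>
    <script>track();</script></body></html>"""

    soup = BeautifulSoup(html, "html.parser")
    for element in soup(["script", "style", "nav", "header", "footer"]):
        element.extract()  # remove non-content nodes in place
    main = soup.find("main") or soup.find("article") or soup
    print(main.get_text(separator="\n", strip=True))  # -> Headline / Body text.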
@@ -287,11 +330,20 @@ def process_social_content(url):
         logger.error(f"Error processing social content: {str(e)}")
         return None
 
-@spaces.GPU(duration=120)
+@spaces.GPU()
 def generate_news(instructions, facts, size, tone, *args):
+    """Generate a news article based on provided data"""
     try:
-        tokenizer, _, news_generator, _ = model_manager.get_models()
+        # Ensure size is integer
+        if isinstance(size, float):
+            size = int(size)
+        elif not isinstance(size, int):
+            size = 250  # Default size
+
+        # Check if models are initialized
+        model_manager.check_llm_initialized()
 
+        # Prepare data structure for inputs
         knowledge_base = {
             "instructions": instructions,
             "facts": facts,
@@ -301,78 +353,123 @@ def generate_news(instructions, facts, size, tone, *args):
             "social_content": []
         }
 
+        # Define the indices for parsing args
         num_audios = 5 * 3
         num_social_urls = 3 * 3
         num_urls = 5
 
+        # Parse arguments
         audios = args[:num_audios]
         social_urls = args[num_audios:num_audios+num_social_urls]
         urls = args[num_audios+num_social_urls:num_audios+num_social_urls+num_urls]
         documents = args[num_audios+num_social_urls+num_urls:]
 
+        # Process URLs with progress reporting
+        logger.info("Processing URLs...")
         for url in urls:
-            if url:
+            if url and isinstance(url, str) and url.strip():
                 content = read_url(url)
                 if content and not content.startswith("Error"):
                     knowledge_base["url_content"].append(content)
 
+        # Process documents
+        logger.info("Processing documents...")
         for document in documents:
             if document is not None:
                 content = read_document(document.name)
                 if content and not content.startswith("Error"):
                     knowledge_base["document_content"].append(content)
 
+        # Process audio/video files
+        logger.info("Processing audio/video files...")
         for i in range(0, len(audios), 3):
-            audio_file, name, position = audios[i:i+3]
-            if audio_file is not None:
-                knowledge_base["audio_data"].append({
-                    "audio": audio_file,
-                    "name": name,
-                    "position": position
-                })
-
-        for i in range(0, len(social_urls), 3):
-            social_url, social_name, social_context = social_urls[i:i+3]
-            if social_url:
-                social_content = process_social_content(social_url)
-                if social_content:
-                    knowledge_base["social_content"].append({
-                        "url": social_url,
-                        "name": social_name,
-                        "context": social_context,
-                        "text": social_content["text"],
-                        "video": social_content["video"]
+            if i+2 < len(audios):  # Ensure we have complete set of 3 elements
+                audio_file, name, position = audios[i:i+3]
+                if audio_file is not None:
+                    knowledge_base["audio_data"].append({
+                        "audio": audio_file,
+                        "name": name or "Unknown",
+                        "position": position or "Not specified"
                     })
 
+        # Process social media content
+        logger.info("Processing social media content...")
+        for i in range(0, len(social_urls), 3):
+            if i+2 < len(social_urls):  # Ensure we have complete set of 3 elements
+                social_url, social_name, social_context = social_urls[i:i+3]
+                if social_url and isinstance(social_url, str) and social_url.strip():
+                    social_content = process_social_content(social_url)
+                    if social_content:
+                        knowledge_base["social_content"].append({
+                            "url": social_url,
+                            "name": social_name or "Unknown",
+                            "context": social_context or "Not specified",
+                            "text": social_content.get("text", ""),
+                            "video": social_content.get("video", "")
+                        })
+
+        # Prepare transcriptions text
         transcriptions_text = ""
         raw_transcriptions = ""
 
+        # Process audio data transcriptions
+        logger.info("Transcribing audio...")
         for idx, data in enumerate(knowledge_base["audio_data"]):
             if data["audio"] is not None:
                 transcription = transcribe_audio(data["audio"])
-                if not transcription.startswith("Error"):
-                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n'
+                if transcription and not transcription.startswith("Error"):
+                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n\n'
                     raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'
 
-        for data in knowledge_base["social_content"]:
+        # Process social media content transcriptions
+        for idx, data in enumerate(knowledge_base["social_content"]):
             if data["text"] and not str(data["text"]).startswith("Error"):
-                transcriptions_text += f'[Social media text]: "{data["text"][:200]}..." - {data["name"]}, {data["context"]}\n'
-                raw_transcriptions += transcriptions_text + "\n\n"
+                # Truncate long texts for the prompt
+                text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
+                social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
+                transcriptions_text += social_text
+                raw_transcriptions += social_text
+
             if data["video"] and not str(data["video"]).startswith("Error"):
-                video_transcription = f'[Social media video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n'
+                video_transcription = f'[Social media {idx+1} - video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n\n'
                 transcriptions_text += video_transcription
-                raw_transcriptions += video_transcription + "\n\n"
-
-        document_content = "\n\n".join(knowledge_base["document_content"])
-        url_content = "\n\n".join(knowledge_base["url_content"])
-
+                raw_transcriptions += video_transcription
+
+        # Combine document content and URL content (with truncation for very long content)
+        document_summaries = []
+        for idx, doc in enumerate(knowledge_base["document_content"]):
+            # Truncate long documents
+            if len(doc) > 1000:
+                doc_excerpt = doc[:1000] + "... [document continues]"
+            else:
+                doc_excerpt = doc
+            document_summaries.append(f"[Document {idx+1}]: {doc_excerpt}")
+
+        document_content = "\n\n".join(document_summaries)
+
+        url_summaries = []
+        for idx, url_content in enumerate(knowledge_base["url_content"]):
+            # Truncate long URL content
+            if len(url_content) > 1000:
+                url_excerpt = url_content[:1000] + "... [content continues]"
+            else:
+                url_excerpt = url_content
+            url_summaries.append(f"[URL {idx+1}]: {url_excerpt}")
+
+        url_content = "\n\n".join(url_summaries)
 
-        prompt = f"""[INST] You are a professional news writer. Write a news article based on the following information:
+        # Create prompt for the model
+        prompt = f"""<s>[INST] You are a professional news writer. Write a news article based on the following information:
 
 Instructions: {knowledge_base["instructions"]}
+
 Facts: {knowledge_base["facts"]}
-Additional content from documents: {document_content}
-Additional content from URLs: {url_content}
+
+Additional content from documents:
+{document_content}
+
+Additional content from URLs:
+{url_content}
 
 Use these transcriptions as direct and indirect quotes:
 {transcriptions_text}
@@ -380,7 +477,7 @@ Use these transcriptions as direct and indirect quotes:
 Follow these requirements:
 - Write a title
 - Write a 15-word hook that complements the title
-- Write the body with {size} words
+- Write the body with approximately {size} words
 - Use a {tone} tone
 - Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
 - Use at least 80% direct quotes (in quotation marks)
@@ -388,27 +485,41 @@ Follow these requirements:
 - Do not invent information
 - Be rigorous with the provided facts [/INST]"""
 
-        # Optimize size and max tokens
-        max_tokens = min(int(size * 1.5), 512)
+        # Optimize for requested size
+        max_new_tokens = min(int(size * 2.5), 1024)  # Increased limit for better quality
 
-        # Generate article with optimized settings
+        # Generate response using optimized unsloth model
         with torch.inference_mode():
             try:
-                news_article = news_generator(
-                    prompt,
-                    max_new_tokens=max_tokens,
-                    num_return_sequences=1,
+                logger.info("Generating news article...")
+                # Use unsloth's optimized generate method
+                inputs = model_manager.tokenizer(
+                    prompt,
+                    return_tensors="pt",
+                    add_special_tokens=False
+                ).to(model_manager.model.device)
+
+                # Generate with optimized settings
+                outputs = model_manager.model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
                     do_sample=True,
                     temperature=0.7,
                     top_p=0.95,
                     repetition_penalty=1.2,
-                    early_stopping=True
+                    pad_token_id=model_manager.tokenizer.eos_token_id,
+                    use_cache=True
+                )
+
+                # Decode the generated text
+                generated_text = model_manager.tokenizer.decode(
+                    outputs[0][inputs.input_ids.shape[1]:],
+                    skip_special_tokens=True
                 )
 
-                # Process the generated text
-                if isinstance(news_article, list):
-                    news_article = news_article[0]['generated_text']
-                news_article = news_article.replace('[INST]', '').replace('[/INST]', '').strip()
+                # Clean up the generated text
+                news_article = generated_text.strip()
+                logger.info(f"News generation completed: {len(news_article)} chars")
 
             except Exception as gen_error:
                 logger.error(f"Error in text generation: {str(gen_error)}")
@@ -419,122 +530,173 @@ Follow these requirements:
     except Exception as e:
         logger.error(f"Error generating news: {str(e)}")
         try:
-            # Attempt to recover by resetting and reinitializing models
-            model_manager.reset_models()
-            model_manager.initialize_models()
-            logger.info("Models reinitialized successfully after error")
-        except Exception as reinit_error:
-            logger.error(f"Failed to reinitialize models: {str(reinit_error)}")
-        return f"Error generating the news article: {str(e)}", ""
+            # Reset models to recover from errors
+            model_manager.reset_models(force=True)
+        except Exception as reset_error:
+            logger.error(f"Failed to reset models: {str(reset_error)}")
+        return f"Error generando la noticia: {str(e)}", "Error procesando las transcripciones."
 
 def create_demo():
-    with gr.Blocks() as demo:
-        gr.Markdown("## Generador de noticias todo en uno")
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# 📰 NewsIA - Generador de Noticias IA")
+        gr.Markdown("Crea noticias profesionales a partir de múltiples fuentes de información.")
 
         with gr.Row():
             with gr.Column(scale=2):
                 instrucciones = gr.Textbox(
                     label="Instrucciones para la noticia",
+                    placeholder="Escribe instrucciones específicas para la generación de tu noticia",
                     lines=2
                 )
                 hechos = gr.Textbox(
-                    label="Describe los hechos de la noticia",
+                    label="Hechos principales",
+                    placeholder="Describe los hechos más importantes que debe incluir la noticia",
                     lines=4
                 )
-                tamaño = gr.Number(
-                    label="Tamaño del cuerpo de la noticia (en palabras)",
-                    value=100
-                )
-                tono = gr.Dropdown(
-                    label="Tono de la noticia",
-                    choices=["serio", "neutral", "divertido"],
-                    value="neutral"
-                )
+
+                with gr.Row():
+                    tamaño = gr.Slider(
+                        label="Longitud aproximada (palabras)",
+                        minimum=100,
+                        maximum=500,
+                        value=250,
+                        step=50
+                    )
+                    tono = gr.Dropdown(
+                        label="Tono de la noticia",
+                        choices=["serio", "neutral", "divertido", "formal", "informal", "urgente"],
+                        value="neutral"
+                    )
 
             with gr.Column(scale=3):
                 inputs_list = [instrucciones, hechos, tamaño, tono]
 
                 with gr.Tabs():
-                    for i in range(1, 6):
-                        with gr.TabItem(f"Audio/Video {i}"):
-                            file = gr.File(
-                                label=f"Audio/Video {i}",
-                                file_types=["audio", "video"]
-                            )
-                            nombre = gr.Textbox(
-                                label="Nombre",
-                                placeholder="Nombre del entrevistado"
-                            )
-                            cargo = gr.Textbox(
-                                label="Cargo",
-                                placeholder="Cargo o rol"
-                            )
-                            inputs_list.extend([file, nombre, cargo])
-
-                    for i in range(1, 4):
-                        with gr.TabItem(f"Red Social {i}"):
-                            social_url = gr.Textbox(
-                                label=f"URL de red social {i}",
-                                placeholder="https://..."
-                            )
-                            social_nombre = gr.Textbox(
-                                label=f"Nombre de persona/cuenta {i}"
-                            )
-                            social_contexto = gr.Textbox(
-                                label=f"Contexto del contenido {i}",
-                                lines=2
-                            )
-                            inputs_list.extend([social_url, social_nombre, social_contexto])
-
-                    for i in range(1, 6):
-                        with gr.TabItem(f"URL {i}"):
+                    with gr.TabItem("📝 Documentos"):
+                        for i in range(1, 4):  # Reduced to 3 for better UX
+                            with gr.Row():
+                                documento = gr.File(
+                                    label=f"Documento {i}",
+                                    file_types=["pdf", "docx", "xlsx", "csv"],
+                                    file_count="single"
+                                )
+                                inputs_list.append(documento)
+
+                        # Add empty inputs to match the original expected array length
+                        for i in range(4, 6):
+                            inputs_list.append(None)
+
+                    with gr.TabItem("🔊 Audio/Video"):
+                        for i in range(1, 4):  # Reduced to 3 for better UX
+                            with gr.Group():
+                                gr.Markdown(f"**Fuente {i}**")
+                                file = gr.File(
+                                    label=f"Audio/Video {i}",
+                                    file_types=["audio", "video"]
+                                )
+                                with gr.Row():
+                                    nombre = gr.Textbox(
+                                        label="Nombre",
+                                        placeholder="Nombre del entrevistado"
+                                    )
+                                    cargo = gr.Textbox(
+                                        label="Cargo/Rol",
+                                        placeholder="Cargo o rol"
+                                    )
+                                inputs_list.extend([file, nombre, cargo])
+
+                        # Add empty inputs to match the original expected array length
+                        for i in range(4, 6):
+                            inputs_list.extend([None, None, None])
+
+                    with gr.TabItem("🌐 URLs"):
+                        for i in range(1, 4):  # Reduced to 3 for better UX
                             url = gr.Textbox(
                                 label=f"URL {i}",
                                 placeholder="https://..."
                             )
                             inputs_list.append(url)
-
-                    for i in range(1, 6):
-                        with gr.TabItem(f"Documento {i}"):
-                            documento = gr.File(
-                                label=f"Documento {i}",
-                                file_types=["pdf", "docx", "xlsx", "csv"],
-                                file_count="single"
-                            )
-                            inputs_list.append(documento)
-
-        gr.Markdown("---")
-
-        with gr.Row():
-            transcripciones_output = gr.Textbox(
-                label="Transcripciones",
-                lines=10,
-                show_copy_button=True
-            )
-
-        gr.Markdown("---")
+
+                        # Add empty inputs to match the original expected array length
+                        for i in range(4, 6):
+                            inputs_list.append(None)
+
+                    with gr.TabItem("📱 Redes Sociales"):
+                        for i in range(1, 3):  # Reduced to 2 for better UX
+                            with gr.Group():
+                                gr.Markdown(f"**Red Social {i}**")
+                                social_url = gr.Textbox(
+                                    label=f"URL",
+                                    placeholder="https://..."
+                                )
+                                with gr.Row():
+                                    social_nombre = gr.Textbox(
+                                        label=f"Nombre/Cuenta",
+                                        placeholder="Nombre de la persona o cuenta"
+                                    )
+                                    social_contexto = gr.Textbox(
+                                        label=f"Contexto",
+                                        placeholder="Contexto relevante"
+                                    )
+                                inputs_list.extend([social_url, social_nombre, social_contexto])
+
+                        # Add empty inputs to match the original expected array length
+                        for i in range(3, 4):
+                            inputs_list.extend([None, None, None])
 
         with gr.Row():
-            generar = gr.Button("Generar borrador")
-
-        with gr.Row():
-            noticia_output = gr.Textbox(
-                label="Borrador generado",
-                lines=20,
-                show_copy_button=True
-            )
+            generar = gr.Button("Generar Noticia", variant="primary")
+            reset = gr.Button("🔄 Limpiar Todo")
+
+        with gr.Tabs():
+            with gr.TabItem("📄 Noticia Generada"):
+                noticia_output = gr.Textbox(
+                    label="Borrador de la noticia",
+                    lines=15,
+                    show_copy_button=True
+                )
+
+            with gr.TabItem("🎙️ Transcripciones"):
+                transcripciones_output = gr.Textbox(
+                    label="Transcripciones de fuentes",
+                    lines=10,
+                    show_copy_button=True
+                )
 
+        # Set up event handlers
         generar.click(
             fn=generate_news,
             inputs=inputs_list,
             outputs=[noticia_output, transcripciones_output]
         )
+
+        # Reset functionality to clear all inputs
+        def reset_all():
+            output = []
+            for _ in range(len(inputs_list)):
+                output.append(None)
+            output.append("")
+            output.append("")
+            return output
+
+        reset.click(
+            fn=reset_all,
+            inputs=[],
+            outputs=inputs_list + [noticia_output, transcripciones_output]
+        )
 
     return demo
 
 if __name__ == "__main__":
+    # Initialize models on startup to reduce first request latency
+    try:
+        model_manager.initialize_whisper()
+        model_manager.initialize_llm()
+    except Exception as e:
+        logger.warning(f"Initial model loading failed: {str(e)}")
+
     demo = create_demo()
-    demo.queue()
+    demo.queue(concurrency_count=1, max_size=5)
     demo.launch(
         share=True,
         server_name="0.0.0.0",