# NewsIA / app.py - Hugging Face Space by CamiloVega
# Last commit: 4e5d878 ("Update app.py", verified); file size 28.6 kB
import spaces
import gradio as gr
import logging
import os
import tempfile
import pandas as pd
import requests
from bs4 import BeautifulSoup
import torch
import whisper
import subprocess
from pydub import AudioSegment
import fitz
import docx
import yt_dlp
from functools import lru_cache
import gc
import time
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Root logging for the whole app: INFO and above, formatted as
# "timestamp - LEVEL - message". Module code logs through `logger`.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
# Optional Hugging Face Hub authentication. The token is read from the
# environment; when it is absent (or empty) no login is attempted.
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)
class ModelManager:
    """Process-wide singleton that lazily loads and caches the app's models.

    Holds the TinyLlama text-generation stack (tokenizer, model, pipeline) and
    the Whisper transcription model, and records when they were last used so
    ``reset_models`` can release them after a period of inactivity.
    """

    _instance = None

    def __new__(cls):
        # Every ModelManager() call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; only set up state the first
        # time so models that are already loaded are not discarded.
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            self.whisper_model = None
            self._initialized = True
            self.last_used = time.time()

    @spaces.GPU()
    def initialize_llm(self):
        """Load the TinyLlama chat model, its tokenizer and a text-generation pipeline.

        Returns:
            True on success.

        Raises:
            Any loading error (logged with traceback, then re-raised).
        """
        try:
            # Use small model for ZeroGPU compatibility
            MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=HUGGINGFACE_TOKEN,
                use_fast=True
            )
            # Padding needs a pad token; reuse EOS when the tokenizer defines none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Basic memory settings for ZeroGPU
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=HUGGINGFACE_TOKEN,
                device_map="auto",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                # Optimizations for ZeroGPU
                max_memory={0: "4GB"},
                offload_folder="offload",
                offload_state_dict=True
            )
            logger.info("Creating pipeline...")
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                max_length=1024
            )
            logger.info("LLM initialized successfully")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.exception(f"Error initializing LLM: {str(e)}")
            raise

    @spaces.GPU()
    def initialize_whisper(self):
        """Load the "tiny" Whisper model (on CUDA when available, else CPU).

        Returns:
            True on success.

        Raises:
            Any loading error (logged with traceback, then re-raised).
        """
        try:
            logger.info("Loading Whisper model...")
            # "tiny" keeps memory use low; a larger size can be swapped in if needed.
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )
            logger.info("Whisper model initialized successfully")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.exception(f"Error initializing Whisper: {str(e)}")
            raise

    def check_llm_initialized(self):
        """Load the LLM stack if any part of it is missing, then mark it as used."""
        if self.tokenizer is None or self.model is None or self.pipeline is None:
            logger.info("LLM not initialized, initializing...")
            self.initialize_llm()
        self.last_used = time.time()

    def check_whisper_initialized(self):
        """Load the Whisper model if it is missing, then mark it as used."""
        if self.whisper_model is None:
            logger.info("Whisper model not initialized, initializing...")
            self.initialize_whisper()
        self.last_used = time.time()

    def reset_models(self, force=False):
        """Release every loaded model if forced or idle for more than 10 minutes.

        Args:
            force: Reset regardless of when the models were last used.

        Errors during cleanup are logged and swallowed.
        """
        if not (force or time.time() - self.last_used > 600):
            return
        try:
            logger.info("Resetting models to free memory...")
            # Dropping our references lets the objects be freed. gc.collect() must run
            # *before* empty_cache(): memory still owned by not-yet-collected
            # reference cycles is not "unoccupied" and would stay cached otherwise.
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            self.whisper_model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            logger.info("Models reset successfully")
        except Exception as e:
            logger.error(f"Error resetting models: {str(e)}")
# The one shared manager used by all request handlers. ModelManager is a
# singleton, so any later ModelManager() call returns this same object.
model_manager: ModelManager = ModelManager()
def download_social_media_video(url):
    """Download the audio track of a social-media video as an MP3 file.

    Uses yt-dlp with the FFmpegExtractAudio post-processor and writes
    ``<video id>.mp3`` to the current working directory.

    Args:
        url: Video/page URL understood by yt-dlp.

    Returns:
        Path of the extracted MP3 file.

    Raises:
        Any yt-dlp error (logged, then re-raised).
    """
    # Deliberately not memoized: transcribe_audio deletes this file once it has
    # been transcribed, so a cached path would point at a missing file the next
    # time the same URL is requested.
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': '%(id)s.%(ext)s',
        # For URLs that reference both a video and a playlist, fetch only the video
        # so info_dict['id'] names the file actually produced.
        'noplaylist': True,
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
            audio_file = f"{info_dict['id']}.mp3"
        logger.info(f"Video downloaded successfully: {audio_file}")
        return audio_file
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        raise
def convert_video_to_audio(video_file):
    """Extract the audio track of a video file into a temporary MP3.

    Args:
        video_file: Path to the input video.

    Returns:
        Path of the new MP3 file (created with delete=False; the caller owns it).

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero (its stderr is logged).
        Any other error while creating the temp file or launching ffmpeg.
        On failure the temporary output file is removed before re-raising.
    """
    output_file = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_file = temp_file.name
        # ffmpeg's documented layout is `ffmpeg [global options] ... -i input ... output`.
        # -y (global: overwrite without asking) goes first; it is required because the
        # temp file created above already exists.
        command = [
            "ffmpeg",
            "-y",
            "-i", video_file,
            "-q:a", "0",  # highest audio quality setting
            "-map", "a",  # keep audio streams only
            "-vn",        # drop video
            output_file,
        ]
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logger.info(f"Video converted to audio: {output_file}")
        return output_file
    except Exception as e:
        # CalledProcessError carries ffmpeg's captured stderr; its tail usually holds
        # the actual reason for the failure.
        stderr = getattr(e, "stderr", None)
        if isinstance(stderr, bytes):
            stderr = stderr.decode(errors="replace")
        details = f" | ffmpeg: {stderr.strip()[-1000:]}" if stderr else ""
        logger.error(f"Error converting video: {str(e)}{details}")
        # Don't leave the (empty) temp output behind.
        if output_file and os.path.exists(output_file):
            try:
                os.remove(output_file)
            except OSError:
                pass
        raise
def preprocess_audio(audio_file):
    """Normalize an audio file's loudness and re-encode it as a temporary MP3.

    The gain is chosen so the clip's average level (``AudioSegment.dBFS``) ends
    up at -20 dBFS. Completely silent input (dBFS == -inf) cannot be normalized
    by any finite gain and is exported unchanged.

    Args:
        audio_file: Path to a file pydub can decode.

    Returns:
        Path of the new MP3 (created with delete=False; the caller owns it).

    Raises:
        Any decoding/encoding error (logged, then re-raised).
    """
    target_dbfs = -20
    try:
        audio = AudioSegment.from_file(audio_file)
        if audio.dBFS != float("-inf"):
            audio = audio.apply_gain(target_dbfs - audio.dBFS)
        # Close our handle before pydub reopens the path for writing; writing to a
        # file that is still open fails on some platforms (e.g. Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_file = temp_file.name
        # export() returns the file object it wrote to; close it explicitly.
        audio.export(output_file, format="mp3").close()
        logger.info(f"Audio preprocessed: {output_file}")
        return output_file
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
@spaces.GPU()
def transcribe_audio(file):
    """Transcribe an audio/video upload or a social-media video URL with Whisper.

    Accepted inputs:
      * ``str`` starting with "http": downloaded via download_social_media_video.
      * ``str`` path ending in .mp4/.avi/.mov/.mkv: audio extracted with ffmpeg.
      * any other ``str`` path, or an object exposing a ``.name`` path (the file
        wrapper some Gradio versions pass): normalized via preprocess_audio.
      * ``None`` or a blank string: returns "" without loading Whisper.

    The intermediate audio file produced by those helpers is deleted afterwards,
    whether or not transcription succeeded.

    Returns:
        The transcription text, or a string starting with "Error" on failure.
        Errors are reported in-band; this function does not raise.
    """
    if file is None or (isinstance(file, str) and not file.strip()):
        return ""  # Nothing to transcribe

    file_path = None
    try:
        model_manager.check_whisper_initialized()
        if isinstance(file, str) and file.startswith('http'):
            file_path = download_social_media_video(file)
        elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(file)
        else:
            # Plain path strings and file wrappers with a .name attribute.
            source_path = file if isinstance(file, str) else file.name
            file_path = preprocess_audio(source_path)

        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        with torch.inference_mode():
            result = model_manager.whisper_model.transcribe(file_path)
        if not result:
            raise RuntimeError("Transcription failed to produce results")
        transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
    finally:
        # Remove the intermediate file on success *and* failure. file_path only ever
        # refers to a file created by the helpers above, never the caller's input.
        if file_path is not None:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Could not remove temp file {file_path}: {str(e)}")
@lru_cache(maxsize=32)
def read_document(document_path):
    """Extract the text content of an uploaded document.

    Dispatches on the file extension, case-insensitively:
      * .pdf          - text of every page via PyMuPDF (fitz), joined by newlines
      * .docx         - paragraph texts via python-docx, joined by newlines
      * .xlsx / .xls  - pandas ``read_excel(...).to_string()``
      * .csv          - pandas ``read_csv(...).to_string()``

    Results are memoized per path (up to 32 entries).

    Returns:
        The extracted text; an "Unsupported file type..." message for other
        extensions; or a string starting with "Error reading document:" on failure.
    """
    try:
        lowered = document_path.lower()
        if lowered.endswith(".pdf"):
            # The context manager closes the PDF, which fitz.open() otherwise leaves open.
            with fitz.open(document_path) as doc:
                return "\n".join(page.get_text() for page in doc)
        if lowered.endswith(".docx"):
            doc = docx.Document(document_path)
            return "\n".join(paragraph.text for paragraph in doc.paragraphs)
        if lowered.endswith((".xlsx", ".xls")):
            return pd.read_excel(document_path).to_string()
        if lowered.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
@lru_cache(maxsize=32)
def _fetch_url_text(url):
    """Download *url* and return its main readable text, capped at 10,000 chars.

    Memoized for up to 32 URLs. Raises on any network/HTTP error; lru_cache
    does not store exceptions, so a failed fetch is retried on the next call.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    response = requests.get(url, headers=headers, timeout=15)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, 'html.parser')
    # Remove non-content elements
    for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
        element.extract()
    # Prefer a dedicated content container; otherwise use the whole page.
    main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
    root = main_content if main_content else soup
    text = root.get_text(separator='\n', strip=True)
    # Drop blank lines and surrounding whitespace.
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    return '\n'.join(lines)[:10000]  # Limit to 10k chars to avoid huge inputs


def read_url(url):
    """Return the main text content of a web page (at most 10,000 characters).

    Args:
        url: Page URL; surrounding whitespace is ignored.

    Returns:
        "" for empty/blank input, the extracted text on success, or a string
        starting with "Error reading URL:" on failure. Successful fetches are
        cached (up to 32 URLs); failures are not, so a transient network error
        is retried on the next request instead of being replayed.
    """
    if not url or url.strip() == "":
        return ""
    try:
        return _fetch_url_text(url.strip())
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Gather the page text and the video transcription for one social-media URL.

    Returns None for blank input or if anything unexpected fails. Otherwise
    returns ``{"text": <read_url result>, "video": <transcribe_audio result>}``,
    where "video" is None if the transcription step raised.
    """
    if not url or not url.strip():
        return None
    try:
        page_text = read_url(url)
        try:
            transcript = transcribe_audio(url)
        except Exception as err:
            logger.error(f"Error processing video content: {str(err)}")
            transcript = None
        return {"text": page_text, "video": transcript}
    except Exception as err:
        logger.error(f"Error processing social content: {str(err)}")
        return None
# Layout of generate_news's *args (everything after instructions, facts, size, tone).
_NUM_DOCUMENTS = 5
_NUM_AUDIO_FIELDS = 5 * 3   # 5 sources x (file, name, position)
_NUM_URLS = 5
_NUM_SOCIAL_FIELDS = 3 * 3  # 3 posts x (url, name/account, context)
_EXCERPT_CHARS = 1000


def _split_news_args(args):
    """Split generate_news's *args into (documents, audios, urls, social_urls).

    The order is the one create_demo registers its inputs in: 5 document
    uploads, 5 audio triples, 5 URLs, 3 social-media triples. Missing values are
    padded with "" and surplus values are ignored.
    """
    args = list(args)
    expected = _NUM_DOCUMENTS + _NUM_AUDIO_FIELDS + _NUM_URLS + _NUM_SOCIAL_FIELDS
    args.extend([""] * (expected - len(args)))
    docs_end = _NUM_DOCUMENTS
    audio_end = docs_end + _NUM_AUDIO_FIELDS
    urls_end = audio_end + _NUM_URLS
    return args[:docs_end], args[docs_end:audio_end], args[audio_end:urls_end], args[urls_end:expected]


def _triples(values):
    """Group a flat list into consecutive 3-tuples, dropping an incomplete tail."""
    return [tuple(values[i:i + 3]) for i in range(0, len(values) - 2, 3)]


def _upload_path(upload):
    """Return the path of an upload (a path string or an object with ``.name``), else None."""
    if isinstance(upload, str):
        return upload if upload.strip() else None
    if upload and hasattr(upload, 'name'):
        return upload.name
    return None


def _excerpt(text, suffix):
    """Return *text*, or its first _EXCERPT_CHARS characters plus *suffix* if longer."""
    if len(text) > _EXCERPT_CHARS:
        return text[:_EXCERPT_CHARS] + suffix
    return text


def _format_prompt(tokenizer, request):
    """Wrap the instruction text in the tokenizer's chat template, if it has one.

    Chat-tuned models publish the prompt format they expect as
    ``tokenizer.chat_template``. Without a template (or if it fails to render),
    the Llama-2 style ``<s>[INST] ... [/INST]`` wrapper is used.
    """
    if getattr(tokenizer, "chat_template", None):
        try:
            return tokenizer.apply_chat_template(
                [{"role": "user", "content": request}],
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as e:
            logger.warning(f"Chat template failed, using [INST] prompt: {str(e)}")
    return f"<s>[INST] {request} [/INST]"


@spaces.GPU()
def generate_news(instructions, facts, size, tone, *args):
    """Generate a news article draft from the form inputs.

    Args:
        instructions: Writing instructions (None is treated as "").
        facts: Key facts for the article (None is treated as "").
        size: Approximate body length in words. Floats are truncated to int; any
            other non-int value falls back to 250.
        tone: Tone name inserted into the prompt.
        *args: Remaining form values in create_demo's order: 5 document uploads,
            5 (file, name, position) audio triples, 5 URLs, and
            3 (url, name, context) social-media triples.

    Returns:
        ``(article, raw_transcriptions)``. On any failure the models are reset and
        a pair of (Spanish) error messages is returned; this function does not raise.
    """
    try:
        if isinstance(size, float):
            size = int(size)
        elif not isinstance(size, int):
            size = 250  # Default size

        model_manager.check_llm_initialized()

        documents, audios, urls, social_urls = _split_news_args(args)

        logger.info("Processing URLs...")
        url_contents = []
        for url in urls:
            if url and isinstance(url, str) and url.strip():
                content = read_url(url)
                if content and not content.startswith("Error"):
                    url_contents.append(content)

        logger.info("Processing documents...")
        document_contents = []
        for document in documents:
            path = _upload_path(document)
            if path:
                content = read_document(path)
                if content and not content.startswith("Error"):
                    document_contents.append(content)

        logger.info("Processing audio/video files...")
        audio_sources = [
            {"audio": audio_file, "name": name or "Unknown", "position": position or "Not specified"}
            for audio_file, name, position in _triples(audios)
            if _upload_path(audio_file)
        ]

        logger.info("Processing social media content...")
        social_sources = []
        for social_url, social_name, social_context in _triples(social_urls):
            if not (social_url and isinstance(social_url, str) and social_url.strip()):
                continue
            social_content = process_social_content(social_url)
            if social_content:
                social_sources.append({
                    "url": social_url,
                    "name": social_name or "Unknown",
                    "context": social_context or "Not specified",
                    "text": social_content.get("text", ""),
                    "video": social_content.get("video", ""),
                })

        quotes = []     # quote lines fed to the prompt
        raw_lines = []  # labelled lines shown to the user
        logger.info("Transcribing audio...")
        for idx, data in enumerate(audio_sources):
            transcription = transcribe_audio(data["audio"])
            if transcription and not transcription.startswith("Error"):
                quotes.append(f'"{transcription}" - {data["name"]}, {data["position"]}\n\n')
                raw_lines.append(f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n')

        for idx, data in enumerate(social_sources):
            if data["text"] and not str(data["text"]).startswith("Error"):
                # Truncate long texts for the prompt
                text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
                social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
                quotes.append(social_text)
                raw_lines.append(social_text)
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media {idx+1} - video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n\n'
                quotes.append(video_transcription)
                raw_lines.append(video_transcription)

        document_content = "\n\n".join(
            f"[Document {idx+1}]: {_excerpt(doc, '... [document continues]')}"
            for idx, doc in enumerate(document_contents)
        )
        url_content = "\n\n".join(
            f"[URL {idx+1}]: {_excerpt(text, '... [content continues]')}"
            for idx, text in enumerate(url_contents)
        )

        request = "\n".join([
            "You are a professional news writer. Write a news article based on the following information:",
            f"Instructions: {instructions or ''}",
            f"Facts: {facts or ''}",
            "Additional content from documents:",
            document_content,
            "Additional content from URLs:",
            url_content,
            "Use these transcriptions as direct and indirect quotes:",
            "".join(quotes),
            "Follow these requirements:",
            "- Write a title",
            "- Write a 15-word hook that complements the title",
            f"- Write the body with approximately {size} words",
            f"- Use a {tone} tone",
            "- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph",
            "- Use at least 80% direct quotes (in quotation marks)",
            "- Use proper journalistic style",
            "- Do not invent information",
            "- Be rigorous with the provided facts",
        ])
        prompt = _format_prompt(model_manager.tokenizer, request)

        try:
            logger.info("Generating news article...")
            # max_new_tokens bounds only the generated article. A max_length budget
            # would also have to cover the prompt, whose token count a word count
            # under-estimates, leaving little or no room for the output.
            outputs = model_manager.pipeline(
                prompt,
                max_new_tokens=size * 2,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                num_return_sequences=1,
                return_full_text=False,  # only the continuation, not the echoed prompt
            )
            news_article = outputs[0]['generated_text'].strip()
            logger.info(f"News generation completed: {len(news_article)} chars")
        except Exception as gen_error:
            logger.error(f"Error in text generation: {str(gen_error)}")
            raise

        return news_article, "".join(raw_lines)
    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            # Reset models to recover from errors
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models: {str(reset_error)}")
        return f"Error generando la noticia: {str(e)}", "Error procesando las transcripciones."
def create_demo():
    """Build the NewsIA Gradio interface.

    Input components are registered in this order, which is the order
    generate_news receives them in: instructions, facts, length, tone, then
    5 document uploads, 5 (file, name, role) audio/video sources, 5 URLs and
    3 (url, name/account, context) social-media entries.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 NewsIA - Generador de Noticias IA")
        gr.Markdown("Crea noticias profesionales a partir de múltiples fuentes de información.")

        with gr.Row():
            with gr.Column(scale=2):
                instrucciones = gr.Textbox(
                    label="Instrucciones para la noticia",
                    placeholder="Escribe instrucciones específicas para la generación de tu noticia",
                    lines=2,
                    value=""
                )
                hechos = gr.Textbox(
                    label="Hechos principales",
                    placeholder="Describe los hechos más importantes que debe incluir la noticia",
                    lines=4,
                    value=""
                )
                with gr.Row():
                    tamano = gr.Slider(
                        label="Longitud aproximada (palabras)",
                        minimum=100,
                        maximum=500,
                        value=250,
                        step=50
                    )
                    tono = gr.Dropdown(
                        label="Tono de la noticia",
                        choices=["serio", "neutral", "divertido", "formal", "informal", "urgente"],
                        value="neutral"
                    )

            with gr.Column(scale=3):
                # generate_news unpacks its inputs positionally, so the order in which
                # components are appended here is part of its contract.
                inputs_list = [instrucciones, hechos, tamano, tono]

                with gr.Tabs():
                    with gr.TabItem("📁 Documentos"):
                        for i in range(1, 6):  # 5 document slots
                            documento = gr.File(
                                label=f"Documento {i}",
                                # Extension filters take a leading dot; .xls is accepted
                                # because read_document parses it as well.
                                file_types=[".pdf", ".docx", ".xlsx", ".xls", ".csv"],
                                file_count="single",
                                value=None
                            )
                            inputs_list.append(documento)

                    with gr.TabItem("🔊 Audio/Video"):
                        for i in range(1, 6):  # 5 audio/video sources
                            with gr.Group():
                                gr.Markdown(f"**Fuente {i}**")
                                audio_file = gr.File(
                                    label=f"Audio/Video {i}",
                                    file_types=["audio", "video"],
                                    value=None
                                )
                                with gr.Row():
                                    nombre = gr.Textbox(
                                        label="Nombre",
                                        placeholder="Nombre del entrevistado",
                                        value=""
                                    )
                                    cargo = gr.Textbox(
                                        label="Cargo/Rol",
                                        placeholder="Cargo o rol",
                                        value=""
                                    )
                                inputs_list.extend([audio_file, nombre, cargo])

                    with gr.TabItem("🌐 URLs"):
                        for i in range(1, 6):  # 5 URL slots
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://...",
                                value=""
                            )
                            inputs_list.append(url)

                    with gr.TabItem("📱 Redes Sociales"):
                        for i in range(1, 4):  # 3 social-media entries
                            with gr.Group():
                                gr.Markdown(f"**Red Social {i}**")
                                social_url = gr.Textbox(
                                    label="URL",
                                    placeholder="https://...",
                                    value=""
                                )
                                with gr.Row():
                                    social_nombre = gr.Textbox(
                                        label="Nombre/Cuenta",
                                        placeholder="Nombre de la persona o cuenta",
                                        value=""
                                    )
                                    social_contexto = gr.Textbox(
                                        label="Contexto",
                                        placeholder="Contexto relevante",
                                        value=""
                                    )
                                inputs_list.extend([social_url, social_nombre, social_contexto])

        with gr.Row():
            generar = gr.Button("✨ Generar Noticia", variant="primary")
            reset = gr.Button("🔄 Limpiar Todo")

        with gr.Tabs():
            with gr.TabItem("📄 Noticia Generada"):
                noticia_output = gr.Textbox(
                    label="Borrador de la noticia",
                    lines=15,
                    show_copy_button=True,
                    value=""
                )
            with gr.TabItem("🎙️ Transcripciones"):
                transcripciones_output = gr.Textbox(
                    label="Transcripciones de fuentes",
                    lines=10,
                    show_copy_button=True,
                    value=""
                )

        generar.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[noticia_output, transcripciones_output]
        )

        def initial_value(component):
            """Value that returns *component* to its freshly-loaded state."""
            # File components are cleared with None (no file) rather than "", and the
            # slider/dropdown go back to their initial values instead of "".
            if isinstance(component, gr.File):
                return None
            if component is tamano:
                return 250
            if component is tono:
                return "neutral"
            return ""

        def reset_all():
            """Clear every input and both outputs."""
            return [initial_value(component) for component in inputs_list] + ["", ""]

        reset.click(
            fn=reset_all,
            inputs=None,
            outputs=inputs_list + [noticia_output, transcripciones_output]
        )

    return demo
if __name__ == "__main__":
    try:
        # Warm Whisper up at startup. Failure is non-fatal: transcribe_audio loads
        # the model on demand via model_manager.check_whisper_initialized().
        model_manager.initialize_whisper()
    except Exception as e:
        logger.warning(f"Initial whisper model loading failed: {str(e)}")

    demo = create_demo()
    # Gradio 4 removed queue(concurrency_count=...) in favour of
    # default_concurrency_limit; use whichever the installed major version supports.
    if int(gr.__version__.split(".")[0]) >= 4:
        demo.queue(default_concurrency_limit=1, max_size=5)
    else:
        demo.queue(concurrency_count=1, max_size=5)
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )