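"""
NewsIA - AI news generator for Hugging Face Spaces.

Gradio app that drafts a news article from mixed sources: free-text facts,
uploaded documents (PDF/DOCX/XLSX/CSV), audio/video files transcribed with
Whisper, plain URLs, and social media links. Text generation uses
TinyLlama/TinyLlama-1.1B-Chat-v1.0 via transformers.
"""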
import spaces
import gradio as gr
import logging
import os
import tempfile
import pandas as pd
import requests
from bs4 import BeautifulSoup
import torch
import whisper
import subprocess
from pydub import AudioSegment
import fitz  # PyMuPDF
import docx
import yt_dlp
from functools import lru_cache
import gc
import time
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Log in to the Hugging Face Hub if a token is available
HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)
class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            self.whisper_model = None
            self._initialized = True
            self.last_used = time.time()
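
    # Design note: a lazily initialized singleton keeps at most one copy of
    # each model in memory, which matters on ZeroGPU Spaces where GPU memory
    # is scarce and the models should be loaded only when actually needed.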
    def initialize_llm(self):
        """Initialize LLM model with standard transformers"""
        try:
            # Use a small model for ZeroGPU compatibility
            MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=HUGGINGFACE_TOKEN,
                use_fast=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Basic memory settings for ZeroGPU
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=HUGGINGFACE_TOKEN,
                device_map="auto",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                # Optimizations for ZeroGPU
                max_memory={0: "4GB"},
                offload_folder="offload",
                offload_state_dict=True
            )

            # Create the text-generation pipeline
            logger.info("Creating pipeline...")
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                torch_dtype=torch.float16,
                device_map="auto",
                max_length=1024
            )

            logger.info("LLM initialized successfully")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.error(f"Error initializing LLM: {str(e)}")
            raise
    def initialize_whisper(self):
        """Initialize Whisper model for audio transcription"""
        try:
            logger.info("Loading Whisper model...")
            # Using the "tiny" model for efficiency; change it if
            # transcription quality matters more than speed
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )
            logger.info("Whisper model initialized successfully")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.error(f"Error initializing Whisper: {str(e)}")
            raise
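
    # For reference, openai-whisper ships checkpoints in increasing size and
    # accuracy: tiny, base, small, medium, large. Each step roughly doubles
    # to triples the parameter count, so "tiny" is the safest default on a
    # shared GPU.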
    def check_llm_initialized(self):
        """Check if LLM is initialized and initialize if needed"""
        if self.tokenizer is None or self.model is None or self.pipeline is None:
            logger.info("LLM not initialized, initializing...")
            self.initialize_llm()
        self.last_used = time.time()

    def check_whisper_initialized(self):
        """Check if Whisper model is initialized and initialize if needed"""
        if self.whisper_model is None:
            logger.info("Whisper model not initialized, initializing...")
            self.initialize_whisper()
        self.last_used = time.time()
    def reset_models(self, force=False):
        """Reset models to free memory if they haven't been used recently"""
        current_time = time.time()
        # Only reset if forced or if the models have been idle for 10 minutes
        if force or (current_time - self.last_used > 600):
            try:
                logger.info("Resetting models to free memory...")
                if hasattr(self, 'model') and self.model is not None:
                    del self.model
                if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                    del self.tokenizer
                if hasattr(self, 'pipeline') and self.pipeline is not None:
                    del self.pipeline
                if hasattr(self, 'whisper_model') and self.whisper_model is not None:
                    del self.whisper_model

                self.tokenizer = None
                self.model = None
                self.pipeline = None
                self.whisper_model = None

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                gc.collect()
                logger.info("Models reset successfully")
            except Exception as e:
                logger.error(f"Error resetting models: {str(e)}")

# Create the global model manager instance
model_manager = ModelManager()
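
# Note: model_manager is a module-level singleton, so every Gradio request in
# this process reuses the same loaded models instead of paying the load cost
# per request.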
def download_social_media_video(url):
    """Download a video from social media and extract its audio as MP3."""
    # yt-dlp's FFmpegExtractAudio postprocessor requires the ffmpeg binary
    # to be available on PATH
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
            audio_file = f"{info_dict['id']}.mp3"
            logger.info(f"Video downloaded successfully: {audio_file}")
            return audio_file
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        raise
def convert_video_to_audio(video_file):
    """Convert a video file to audio using ffmpeg directly."""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_file = temp_file.name

        # Call ffmpeg via subprocess; note that -y must precede the output
        # path, otherwise ffmpeg treats it as a trailing option and ignores it
        command = [
            "ffmpeg",
            "-i", video_file,
            "-vn",        # drop the video stream
            "-map", "a",  # select the audio stream(s)
            "-q:a", "0",  # best VBR audio quality
            "-y",         # overwrite the output file if it exists
            output_file,
        ]
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logger.info(f"Video converted to audio: {output_file}")
        return output_file
    except Exception as e:
        logger.error(f"Error converting video: {str(e)}")
        raise
def preprocess_audio(audio_file):
    """Preprocess the audio file: normalize loudness to roughly -20 dBFS."""
    try:
        audio = AudioSegment.from_file(audio_file)
        # Apply exactly enough gain to bring the average level to -20 dBFS
        audio = audio.apply_gain(-audio.dBFS - 20)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio.export(temp_file.name, format="mp3")
            logger.info(f"Audio preprocessed: {temp_file.name}")
            return temp_file.name
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
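
# Worked example of the normalization above: a quiet clip measured at
# -35 dBFS receives apply_gain(-(-35) - 20) = +15 dB, while a hot clip at
# -12 dBFS receives apply_gain(-(-12) - 20) = -8 dB; both land near -20 dBFS.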
def transcribe_audio(file):
    """Transcribe an audio or video file."""
    try:
        model_manager.check_whisper_initialized()

        if isinstance(file, str) and file.startswith('http'):
            file_path = download_social_media_video(file)
        elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(file)
        elif file is not None:
            # Handle an upload from Gradio, which may arrive as a plain path
            # string or as a file-like object with a .name attribute
            file_path = preprocess_audio(file if isinstance(file, str) else file.name)
        else:
            return ""  # Return an empty string for None input

        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        with torch.inference_mode():
            result = model_manager.whisper_model.transcribe(file_path)
            if not result:
                raise RuntimeError("Transcription failed to produce results")
            transcription = result.get("text", "Error in transcription")
            logger.info(f"Transcription completed: {transcription[:50]}...")

        # Clean up the temporary file
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
        except Exception as e:
            logger.warning(f"Could not remove temp file {file_path}: {str(e)}")

        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
def read_document(document_path):
    """Read the text content of a document."""
    try:
        if document_path.endswith(".pdf"):
            with fitz.open(document_path) as doc:
                return "\n".join(page.get_text() for page in doc)
        elif document_path.endswith(".docx"):
            doc = docx.Document(document_path)
            return "\n".join(paragraph.text for paragraph in doc.paragraphs)
        elif document_path.endswith((".xlsx", ".xls")):
            return pd.read_excel(document_path).to_string()
        elif document_path.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        else:
            return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
def read_url(url):
    """Read the main text content of a URL."""
    if not url or url.strip() == "":
        return ""
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(url, headers=headers, timeout=15)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')

        # Remove non-content elements
        for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
            element.extract()

        # Prefer a main content container when one can be identified
        main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
        if main_content:
            text = main_content.get_text(separator='\n', strip=True)
        else:
            text = soup.get_text(separator='\n', strip=True)

        # Collapse blank lines and stray whitespace
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        text = '\n'.join(lines)

        return text[:10000]  # Limit to 10k characters to avoid huge prompts
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Process social media content: page text plus any video transcription."""
    if not url or url.strip() == "":
        return None
    try:
        text_content = read_url(url)
        try:
            video_content = transcribe_audio(url)
        except Exception as e:
            logger.error(f"Error processing video content: {str(e)}")
            video_content = None
        return {
            "text": text_content,
            "video": video_content
        }
    except Exception as e:
        logger.error(f"Error processing social content: {str(e)}")
        return None
def generate_news(instructions, facts, size, tone, *args):
    """Generate a news article based on provided data"""
    try:
        # Ensure size is an integer
        if isinstance(size, float):
            size = int(size)
        elif not isinstance(size, int):
            size = 250  # Default size

        # Make sure the LLM is loaded
        model_manager.check_llm_initialized()

        # Prepare the data structure for the inputs
        knowledge_base = {
            "instructions": instructions or "",
            "facts": facts or "",
            "document_content": [],
            "audio_data": [],
            "url_content": [],
            "social_content": []
        }

        # Slice *args in the same order create_demo() appends components:
        # 5 documents, then 5 audio/video triples (file, name, position),
        # then 5 URLs, then 3 social triples (url, name, context). The
        # original sliced audios first, which misassigned every input.
        num_documents = 5
        num_audios = 5 * 3
        num_urls = 5
        num_social_urls = 3 * 3

        args = list(args)  # Convert the tuple to a list for padding
        # Pad with empty strings in case fewer arguments arrive
        while len(args) < (num_documents + num_audios + num_urls + num_social_urls):
            args.append("")

        documents = args[:num_documents]
        audios = args[num_documents:num_documents + num_audios]
        urls = args[num_documents + num_audios:num_documents + num_audios + num_urls]
        social_urls = args[num_documents + num_audios + num_urls:
                           num_documents + num_audios + num_urls + num_social_urls]
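
        # Resulting layout of *args (index ranges after padding to 34 items):
        #   [0:5]   -> document uploads
        #   [5:20]  -> audio/video triples: file, speaker name, role
        #   [20:25] -> plain URLs
        #   [25:34] -> social triples: url, account name, context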
        # Process URLs
        logger.info("Processing URLs...")
        for url in urls:
            if url and isinstance(url, str) and url.strip():
                content = read_url(url)
                if content and not content.startswith("Error"):
                    knowledge_base["url_content"].append(content)

        # Process documents
        logger.info("Processing documents...")
        for document in documents:
            if document:
                # Accept either a path string or a file object from Gradio
                doc_path = document if isinstance(document, str) else getattr(document, 'name', None)
                if doc_path:
                    content = read_document(doc_path)
                    if content and not content.startswith("Error"):
                        knowledge_base["document_content"].append(content)

        # Process audio/video files
        logger.info("Processing audio/video files...")
        for i in range(0, len(audios), 3):
            if i + 2 < len(audios):  # Ensure we have a complete triple
                audio_file, name, position = audios[i:i+3]
                if audio_file:
                    knowledge_base["audio_data"].append({
                        "audio": audio_file,
                        "name": name or "Unknown",
                        "position": position or "Not specified"
                    })

        # Process social media content
        logger.info("Processing social media content...")
        for i in range(0, len(social_urls), 3):
            if i + 2 < len(social_urls):  # Ensure we have a complete triple
                social_url, social_name, social_context = social_urls[i:i+3]
                if social_url and isinstance(social_url, str) and social_url.strip():
                    social_content = process_social_content(social_url)
                    if social_content:
                        knowledge_base["social_content"].append({
                            "url": social_url,
                            "name": social_name or "Unknown",
                            "context": social_context or "Not specified",
                            "text": social_content.get("text", ""),
                            "video": social_content.get("video", "")
                        })
        # Prepare the transcription text blocks
        transcriptions_text = ""
        raw_transcriptions = ""

        # Transcribe the uploaded audio/video sources
        logger.info("Transcribing audio...")
        for idx, data in enumerate(knowledge_base["audio_data"]):
            if data["audio"] is not None:
                transcription = transcribe_audio(data["audio"])
                if transcription and not transcription.startswith("Error"):
                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n\n'
                    raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'

        # Add the social media content
        for idx, data in enumerate(knowledge_base["social_content"]):
            if data["text"] and not str(data["text"]).startswith("Error"):
                # Truncate long texts for the prompt
                text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
                social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
                transcriptions_text += social_text
                raw_transcriptions += social_text
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media {idx+1} - video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n\n'
                transcriptions_text += video_transcription
                raw_transcriptions += video_transcription

        # Combine document and URL content, truncating very long inputs
        document_summaries = []
        for idx, doc in enumerate(knowledge_base["document_content"]):
            if len(doc) > 1000:
                doc_excerpt = doc[:1000] + "... [document continues]"
            else:
                doc_excerpt = doc
            document_summaries.append(f"[Document {idx+1}]: {doc_excerpt}")
        document_content = "\n\n".join(document_summaries)

        url_summaries = []
        for idx, url_content in enumerate(knowledge_base["url_content"]):
            if len(url_content) > 1000:
                url_excerpt = url_content[:1000] + "... [content continues]"
            else:
                url_excerpt = url_content
            url_summaries.append(f"[URL {idx+1}]: {url_excerpt}")
        url_content = "\n\n".join(url_summaries)
        # Create the prompt for the model
        prompt = f"""<s>[INST] You are a professional news writer. Write a news article based on the following information:

Instructions: {knowledge_base["instructions"]}

Facts: {knowledge_base["facts"]}

Additional content from documents:
{document_content}

Additional content from URLs:
{url_content}

Use these transcriptions as direct and indirect quotes:
{transcriptions_text}

Follow these requirements:
- Write a title
- Write a 15-word hook that complements the title
- Write the body with approximately {size} words
- Use a {tone} tone
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
- Use at least 80% direct quotes (in quotation marks)
- Use proper journalistic style
- Do not invent information
- Be rigorous with the provided facts [/INST]"""
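
        # Note (assumption): TinyLlama-1.1B-Chat-v1.0 was fine-tuned on the
        # Zephyr-style chat format rather than the Llama-2 [INST] format used
        # above. The [INST] markers still work as plain delimiter text, and
        # the split on "[/INST]" below depends on them, but
        # tokenizer.apply_chat_template() would be the more faithful route.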
        # Generate with the standard pipeline
        try:
            logger.info("Generating news article...")
            # Cap the generation length based on the requested size.
            # len(prompt.split()) counts words, only a rough proxy for
            # tokens, so the result is clamped to the 2048-token budget.
            max_length = min(len(prompt.split()) + size * 2, 2048)

            outputs = model_manager.pipeline(
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                num_return_sequences=1
            )

            # Extract the generated text and strip the prompt from it
            generated_text = outputs[0]['generated_text']
            if "[/INST]" in generated_text:
                news_article = generated_text.split("[/INST]")[1].strip()
            else:
                # Fall back to locating the prompt by its first 50 words
                prompt_words = prompt.split()[:50]
                prompt_fragment = " ".join(prompt_words)
                if prompt_fragment in generated_text:
                    news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
                else:
                    news_article = generated_text

            logger.info(f"News generation completed: {len(news_article)} chars")
        except Exception as gen_error:
            logger.error(f"Error in text generation: {str(gen_error)}")
            raise

        return news_article, raw_transcriptions
    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            # Reset the models to recover from errors
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models: {str(reset_error)}")
        return f"Error generando la noticia: {str(e)}", "Error procesando las transcripciones."
def create_demo():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 NewsIA - Generador de Noticias IA")
        gr.Markdown("Crea noticias profesionales a partir de múltiples fuentes de información.")

        with gr.Row():
            with gr.Column(scale=2):
                instrucciones = gr.Textbox(
                    label="Instrucciones para la noticia",
                    placeholder="Escribe instrucciones específicas para la generación de tu noticia",
                    lines=2,
                    value=""
                )
                hechos = gr.Textbox(
                    label="Hechos principales",
                    placeholder="Describe los hechos más importantes que debe incluir la noticia",
                    lines=4,
                    value=""
                )
                with gr.Row():
                    tamaño = gr.Slider(
                        label="Longitud aproximada (palabras)",
                        minimum=100,
                        maximum=500,
                        value=250,
                        step=50
                    )
                    tono = gr.Dropdown(
                        label="Tono de la noticia",
                        choices=["serio", "neutral", "divertido", "formal", "informal", "urgente"],
                        value="neutral"
                    )
            with gr.Column(scale=3):
                # Start the input list with the known fixed components; the
                # order here must match the slicing in generate_news()
                inputs_list = [instrucciones, hechos, tamaño, tono]
                with gr.Tabs():
                    with gr.TabItem("📄 Documentos"):
                        documentos = []
                        for i in range(1, 6):  # Five document slots, as in the original
                            documento = gr.File(
                                label=f"Documento {i}",
                                file_types=["pdf", "docx", "xlsx", "csv"],
                                file_count="single",
                                value=None
                            )
                            documentos.append(documento)
                            inputs_list.append(documento)

                    with gr.TabItem("🎤 Audio/Video"):
                        for i in range(1, 6):  # Five audio/video sources, as in the original
                            with gr.Group():
                                gr.Markdown(f"**Fuente {i}**")
                                file = gr.File(
                                    label=f"Audio/Video {i}",
                                    file_types=["audio", "video"],
                                    value=None
                                )
                                with gr.Row():
                                    nombre = gr.Textbox(
                                        label="Nombre",
                                        placeholder="Nombre del entrevistado",
                                        value=""
                                    )
                                    cargo = gr.Textbox(
                                        label="Cargo/Rol",
                                        placeholder="Cargo o rol",
                                        value=""
                                    )
                                inputs_list.append(file)
                                inputs_list.append(nombre)
                                inputs_list.append(cargo)

                    with gr.TabItem("🔗 URLs"):
                        for i in range(1, 6):  # Five URL slots, as in the original
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://...",
                                value=""
                            )
                            inputs_list.append(url)

                    with gr.TabItem("📱 Redes Sociales"):
                        for i in range(1, 4):  # Three social media slots, as in the original
                            with gr.Group():
                                gr.Markdown(f"**Red Social {i}**")
                                social_url = gr.Textbox(
                                    label="URL",
                                    placeholder="https://...",
                                    value=""
                                )
                                with gr.Row():
                                    social_nombre = gr.Textbox(
                                        label="Nombre/Cuenta",
                                        placeholder="Nombre de la persona o cuenta",
                                        value=""
                                    )
                                    social_contexto = gr.Textbox(
                                        label="Contexto",
                                        placeholder="Contexto relevante",
                                        value=""
                                    )
                                inputs_list.append(social_url)
                                inputs_list.append(social_nombre)
                                inputs_list.append(social_contexto)
        with gr.Row():
            generar = gr.Button("✨ Generar Noticia", variant="primary")
            reset = gr.Button("🔄 Limpiar Todo")

        with gr.Tabs():
            with gr.TabItem("📝 Noticia Generada"):
                noticia_output = gr.Textbox(
                    label="Borrador de la noticia",
                    lines=15,
                    show_copy_button=True,
                    value=""
                )
            with gr.TabItem("🎙️ Transcripciones"):
                transcripciones_output = gr.Textbox(
                    label="Transcripciones de fuentes",
                    lines=10,
                    show_copy_button=True,
                    value=""
                )

        # Set up the event handlers
        generar.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[noticia_output, transcripciones_output]
        )

        # Reset every input to a value appropriate for its component type:
        # the original returned "" for everything, which is invalid for the
        # file, slider, and dropdown components
        def reset_all():
            values = []
            for component in inputs_list:
                if isinstance(component, gr.File):
                    values.append(None)
                elif isinstance(component, gr.Slider):
                    values.append(250)
                elif isinstance(component, gr.Dropdown):
                    values.append("neutral")
                else:
                    values.append("")
            return values + ["", ""]

        reset.click(
            fn=reset_all,
            inputs=None,
            outputs=inputs_list + [noticia_output, transcripciones_output]
        )
    return demo
if __name__ == "__main__":
    try:
        # Try to initialize the Whisper model on startup
        model_manager.initialize_whisper()
    except Exception as e:
        logger.warning(f"Initial whisper model loading failed: {str(e)}")

    demo = create_demo()
    # Note: queue(concurrency_count=...) is the Gradio 3.x signature;
    # Gradio 4.x renamed it to default_concurrency_limit
    demo.queue(concurrency_count=1, max_size=5)
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )