import spaces
import gradio as gr
import logging
import os
import tempfile
import pandas as pd
import requests
from bs4 import BeautifulSoup
import torch
import whisper
import subprocess
from pydub import AudioSegment
import fitz
import docx
import yt_dlp
from functools import lru_cache
import gc
import time
from huggingface_hub import login
from unsloth import FastLanguageModel
from transformers import AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Login to Hugging Face Hub if a token is available
HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)
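
# ModelManager is a lazy-loading singleton: Whisper and the LLM are only pulled
# into memory the first time they are needed, and reset_models() releases them
# again after ten minutes of inactivity (or immediately with force=True).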
class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.pipeline = None
            self.whisper_model = None
            self._initialized = True
            self.last_used = time.time()
    def initialize_llm(self):
        """Initialize LLM model with Unsloth optimization"""
        try:
            MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            logger.info("Loading Unsloth-optimized model...")
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=MODEL_NAME,
                max_seq_length=2048,
                dtype=torch.float16,
                load_in_4bit=True,
                token=HUGGINGFACE_TOKEN,
            )
            # Enable LoRA for better ZeroGPU performance
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=16,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
                lora_alpha=16,
                lora_dropout=0,
                bias="none",
                use_gradient_checkpointing=True,
                random_state=3407,
                max_seq_length=2048,
            )
            logger.info("LLM initialized successfully with Unsloth")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.error(f"Error initializing LLM: {str(e)}")
            raise
    def initialize_whisper(self):
        """Initialize Whisper model"""
        try:
            logger.info("Loading Whisper model...")
            # Note: whisper.load_model() does not accept a weights_only argument,
            # so the checkpoint is loaded with the library's default settings.
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )
            logger.info("Whisper model initialized successfully")
            self.last_used = time.time()
            return True
        except Exception as e:
            logger.error(f"Error initializing Whisper: {str(e)}")
            raise
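
    # The "tiny" checkpoint above keeps memory usage low; larger checkpoints
    # ("base", "small", ...) trade VRAM and latency for transcription accuracy.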
    def check_llm_initialized(self):
        """Check if LLM is initialized and initialize if needed"""
        if self.tokenizer is None or self.model is None:
            logger.info("LLM not initialized, initializing...")
            self.initialize_llm()
        self.last_used = time.time()

    def check_whisper_initialized(self):
        """Check if Whisper model is initialized and initialize if needed"""
        if self.whisper_model is None:
            logger.info("Whisper model not initialized, initializing...")
            self.initialize_whisper()
        self.last_used = time.time()

    def reset_models(self, force=False):
        """Reset models to free memory if they haven't been used recently"""
        current_time = time.time()
        if force or (current_time - self.last_used > 600):
            try:
                logger.info("Resetting models to free memory...")
                if self.model is not None:
                    del self.model
                if self.tokenizer is not None:
                    del self.tokenizer
                if self.whisper_model is not None:
                    del self.whisper_model
                self.tokenizer = None
                self.model = None
                self.whisper_model = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                gc.collect()
                logger.info("Models reset successfully")
            except Exception as e:
                logger.error(f"Error resetting models: {str(e)}")


model_manager = ModelManager()
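
# Both the yt-dlp FFmpegExtractAudio postprocessor and convert_video_to_audio()
# below shell out to ffmpeg, so an ffmpeg binary must be available on PATH
# (on Hugging Face Spaces this typically means listing it in packages.txt).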
def download_social_media_video(url):
    """Download a video from social media."""
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
            audio_file = f"{info_dict['id']}.mp3"
            logger.info(f"Video downloaded successfully: {audio_file}")
            return audio_file
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        raise
def convert_video_to_audio(video_file):
    """Convert a video file to audio using ffmpeg directly."""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_file = temp_file.name
        command = [
            "ffmpeg",
            "-y",  # overwrite the placeholder temp file created above
            "-i", video_file,
            "-q:a", "0",
            "-map", "a",
            "-vn",
            output_file
        ]
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logger.info(f"Video converted to audio: {output_file}")
        return output_file
    except Exception as e:
        logger.error(f"Error converting video: {str(e)}")
        raise
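
# preprocess_audio() normalizes loudness to roughly -20 dBFS, which tends to help
# Whisper with quiet or unevenly recorded interview audio.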
def preprocess_audio(audio_file):
    """Preprocess the audio file to improve quality."""
    try:
        audio = AudioSegment.from_file(audio_file)
        audio = audio.apply_gain(-audio.dBFS + (-20))
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio.export(temp_file.name, format="mp3")
            logger.info(f"Audio preprocessed: {temp_file.name}")
            return temp_file.name
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
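
# transcribe_audio() accepts a URL (downloaded via yt-dlp), a local video path
# (converted with ffmpeg), or an uploaded audio file (normalized with pydub),
# and runs the result through Whisper.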
def transcribe_audio(file):
    """Transcribe an audio or video file."""
    try:
        model_manager.check_whisper_initialized()
        if isinstance(file, str) and file.startswith('http'):
            file_path = download_social_media_video(file)
        elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(file)
        elif file is not None:
            # Gradio may hand back a tempfile-like object or a plain path string
            file_path = preprocess_audio(file.name if hasattr(file, 'name') else file)
        else:
            return ""
        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        with torch.inference_mode():
            result = model_manager.whisper_model.transcribe(file_path)
            transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
        except Exception as e:
            logger.warning(f"Could not remove temp file {file_path}: {str(e)}")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
def read_document(document_path):
    """Read the content of a document."""
    try:
        if document_path.endswith(".pdf"):
            doc = fitz.open(document_path)
            return "\n".join([page.get_text() for page in doc])
        elif document_path.endswith(".docx"):
            doc = docx.Document(document_path)
            return "\n".join([paragraph.text for paragraph in doc.paragraphs])
        elif document_path.endswith((".xlsx", ".xls")):
            return pd.read_excel(document_path).to_string()
        elif document_path.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        else:
            return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
def read_url(url):
    """Read the content of a URL."""
    if not url or url.strip() == "":
        return ""
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(url, headers=headers, timeout=15)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav"]):
            element.extract()
        main_content = soup.find("main") or soup.find("article") or soup.find("div", class_=["content", "main", "article"])
        if main_content:
            text = main_content.get_text(separator='\n', strip=True)
        else:
            text = soup.get_text(separator='\n', strip=True)
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        text = '\n'.join(lines)
        return text[:10000]
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Process social media content."""
    if not url or url.strip() == "":
        return None
    try:
        text_content = read_url(url)
        try:
            video_content = transcribe_audio(url)
        except Exception as e:
            logger.error(f"Error processing video content: {str(e)}")
            video_content = None
        return {
            "text": text_content,
            "video": video_content
        }
    except Exception as e:
        logger.error(f"Error processing social content: {str(e)}")
        return None
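
# generate_news() receives the flattened extra inputs from the Gradio inputs_list
# built in create_demo(): 5 document files, then 5 x (audio/video file, name,
# position), then 5 URLs, then 3 x (social URL, account name, context).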
def generate_news(instructions, facts, size, tone, *args):
    """Generate a news article based on the provided data."""
    try:
        if isinstance(size, float):
            size = int(size)
        elif not isinstance(size, int):
            size = 250

        model_manager.check_llm_initialized()

        knowledge_base = {
            "instructions": instructions or "",
            "facts": facts or "",
            "document_content": [],
            "audio_data": [],
            "url_content": [],
            "social_content": []
        }

        # Split *args in the same order the UI registers its components:
        # documents, audio/video triplets, URLs, social media triplets.
        num_documents = 5
        num_audios = 5 * 3
        num_urls = 5
        num_social_urls = 3 * 3
        args = list(args)
        while len(args) < (num_documents + num_audios + num_urls + num_social_urls):
            args.append("")
        documents = args[:num_documents]
        audios = args[num_documents:num_documents + num_audios]
        urls = args[num_documents + num_audios:num_documents + num_audios + num_urls]
        social_urls = args[num_documents + num_audios + num_urls:num_documents + num_audios + num_urls + num_social_urls]

        logger.info("Processing URLs...")
        for url in urls:
            if url and isinstance(url, str) and url.strip():
                content = read_url(url)
                if content and not content.startswith("Error"):
                    knowledge_base["url_content"].append(content)

        logger.info("Processing documents...")
        for document in documents:
            if document:
                doc_path = document.name if hasattr(document, 'name') else document
                content = read_document(doc_path)
                if content and not content.startswith("Error"):
                    knowledge_base["document_content"].append(content)

        logger.info("Processing audio/video files...")
        for i in range(0, len(audios), 3):
            if i + 2 < len(audios):
                audio_file, name, position = audios[i:i+3]
                if audio_file:
                    knowledge_base["audio_data"].append({
                        "audio": audio_file,
                        "name": name or "Unknown",
                        "position": position or "Not specified"
                    })

        logger.info("Processing social media content...")
        for i in range(0, len(social_urls), 3):
            if i + 2 < len(social_urls):
                social_url, social_name, social_context = social_urls[i:i+3]
                if social_url and isinstance(social_url, str) and social_url.strip():
                    social_content = process_social_content(social_url)
                    if social_content:
                        knowledge_base["social_content"].append({
                            "url": social_url,
                            "name": social_name or "Unknown",
                            "context": social_context or "Not specified",
                            "text": social_content.get("text", ""),
                            "video": social_content.get("video", "")
                        })
        transcriptions_text = ""
        raw_transcriptions = ""

        logger.info("Transcribing audio...")
        for idx, data in enumerate(knowledge_base["audio_data"]):
            if data["audio"] is not None:
                transcription = transcribe_audio(data["audio"])
                if transcription and not transcription.startswith("Error"):
                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n\n'
                    raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'

        for idx, data in enumerate(knowledge_base["social_content"]):
            if data["text"] and not str(data["text"]).startswith("Error"):
                text_excerpt = data["text"][:500] + "..." if len(data["text"]) > 500 else data["text"]
                social_text = f'[Social media {idx+1} - text]: "{text_excerpt}" - {data["name"]}, {data["context"]}\n\n'
                transcriptions_text += social_text
                raw_transcriptions += social_text
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media {idx+1} - video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n\n'
                transcriptions_text += video_transcription
                raw_transcriptions += video_transcription

        document_summaries = []
        for idx, doc in enumerate(knowledge_base["document_content"]):
            if len(doc) > 1000:
                doc_excerpt = doc[:1000] + "... [document continues]"
            else:
                doc_excerpt = doc
            document_summaries.append(f"[Document {idx+1}]: {doc_excerpt}")
        document_content = "\n\n".join(document_summaries)

        url_summaries = []
        for idx, url_content in enumerate(knowledge_base["url_content"]):
            if len(url_content) > 1000:
                url_excerpt = url_content[:1000] + "... [content continues]"
            else:
                url_excerpt = url_content
            url_summaries.append(f"[URL {idx+1}]: {url_excerpt}")
        url_content = "\n\n".join(url_summaries)

        prompt = f"""<s>[INST] You are a professional news writer. Write a news article based on the following information:

Instructions: {knowledge_base["instructions"]}

Facts: {knowledge_base["facts"]}

Additional content from documents:
{document_content}

Additional content from URLs:
{url_content}

Use these transcriptions as direct and indirect quotes:
{transcriptions_text}

Follow these requirements:
- Write a title
- Write a 15-word hook that complements the title
- Write the body with approximately {size} words
- Use a {tone} tone
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
- Use at least 80% direct quotes (in quotation marks)
- Use proper journalistic style
- Do not invent information
- Be rigorous with the provided facts [/INST]"""
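
        # Note: this prompt uses Llama/Mistral-style [INST] ... [/INST] markers, and
        # the post-processing below splits on "[/INST]" to strip the prompt from the
        # decoded output. TinyLlama-1.1B-Chat was tuned on a Zephyr-style chat
        # template, so switching to tokenizer.apply_chat_template() would likely
        # improve output quality (left as-is here).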
        try:
            logger.info("Generating news article...")
            inputs = model_manager.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            ).to(model_manager.model.device)
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=size + 100,
                temperature=0.7,
                do_sample=True,
                pad_token_id=model_manager.tokenizer.eos_token_id,
            )
            generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "[/INST]" in generated_text:
                news_article = generated_text.split("[/INST]")[1].strip()
            else:
                prompt_fragment = " ".join(prompt.split()[:50])
                if prompt_fragment in generated_text:
                    news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
                else:
                    news_article = generated_text
            logger.info(f"News generation completed: {len(news_article)} chars")
        except Exception as gen_error:
            logger.error(f"Error in text generation: {str(gen_error)}")
            raise

        return news_article, raw_transcriptions
    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models: {str(reset_error)}")
        return f"Error generating news: {str(e)}", "Error processing transcriptions."
def create_demo():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 NewsIA - AI News Generator")
        gr.Markdown("Create professional news articles from multiple sources.")

        with gr.Row():
            with gr.Column(scale=2):
                instructions = gr.Textbox(
                    label="News Instructions",
                    placeholder="Enter specific instructions for news generation",
                    lines=2
                )
                facts = gr.Textbox(
                    label="Key Facts",
                    placeholder="Describe the most important facts to include",
                    lines=4
                )
                with gr.Row():
                    size = gr.Slider(
                        label="Approximate Length (words)",
                        minimum=100,
                        maximum=500,
                        value=250,
                        step=50
                    )
                    tone = gr.Dropdown(
                        label="News Tone",
                        choices=["serious", "neutral", "funny", "formal", "informal", "urgent"],
                        value="neutral"
                    )
            with gr.Column(scale=3):
                inputs_list = []
                inputs_list.extend([instructions, facts, size, tone])
                with gr.Tabs():
                    with gr.TabItem("📄 Documents"):
                        documents = []
                        for i in range(1, 6):
                            doc = gr.File(
                                label=f"Document {i}",
                                file_types=["pdf", "docx", "xlsx", "csv"],
                                file_count="single"
                            )
                            documents.append(doc)
                            inputs_list.append(doc)
                    with gr.TabItem("🎵 Audio/Video"):
                        for i in range(1, 6):
                            with gr.Group():
                                gr.Markdown(f"**Source {i}**")
                                file = gr.File(
                                    label=f"Audio/Video {i}",
                                    file_types=["audio", "video"]
                                )
                                with gr.Row():
                                    name = gr.Textbox(
                                        label="Name",
                                        placeholder="Interviewee name"
                                    )
                                    position = gr.Textbox(
                                        label="Position/Role",
                                        placeholder="Position or role"
                                    )
                                inputs_list.extend([file, name, position])
                    with gr.TabItem("🌐 URLs"):
                        for i in range(1, 6):
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://..."
                            )
                            inputs_list.append(url)
                    with gr.TabItem("📱 Social Media"):
                        for i in range(1, 4):
                            with gr.Group():
                                gr.Markdown(f"**Social Media {i}**")
                                social_url = gr.Textbox(
                                    label="URL",
                                    placeholder="https://..."
                                )
                                with gr.Row():
                                    social_name = gr.Textbox(
                                        label="Account/Name",
                                        placeholder="Account or person name"
                                    )
                                    social_context = gr.Textbox(
                                        label="Context",
                                        placeholder="Relevant context"
                                    )
                                inputs_list.extend([social_url, social_name, social_context])

        with gr.Row():
            generate_btn = gr.Button("✨ Generate News", variant="primary")
            reset_btn = gr.Button("🔄 Clear All")

        with gr.Tabs():
            with gr.TabItem("📝 Generated News"):
                news_output = gr.Textbox(
                    label="News Draft",
                    lines=15,
                    show_copy_button=True
                )
            with gr.TabItem("🎙️ Transcriptions"):
                transcriptions_output = gr.Textbox(
                    label="Source Transcriptions",
                    lines=10,
                    show_copy_button=True
                )

        generate_btn.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[news_output, transcriptions_output]
        )

        def reset_all():
            return [None] * len(inputs_list) + ["", ""]

        reset_btn.click(
            fn=reset_all,
            inputs=None,
            outputs=inputs_list + [news_output, transcriptions_output]
        )
    return demo
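
# Pre-load Whisper at startup so the first transcription does not pay the model
# download cost; the LLM is still loaded lazily on the first generation request.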
if __name__ == "__main__":
    try:
        model_manager.initialize_whisper()
    except Exception as e:
        logger.warning(f"Initial whisper model loading failed: {str(e)}")

    demo = create_demo()
    demo.queue(max_size=5)
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )