Update app.py
app.py CHANGED
@@ -10,15 +10,15 @@ import torch
 import whisper
 import subprocess
 from pydub import AudioSegment
-import fitz
+import fitz  # PyMuPDF
 import docx
 import yt_dlp
 from functools import lru_cache
 import gc
 import time
 from huggingface_hub import login
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import traceback  # For detailed error logging
 
 # Configure logging
 logging.basicConfig(
@@ -30,626 +30,1171 @@ logger = logging.getLogger(__name__)
 # Login to Hugging Face Hub if token is available
 HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
 if HUGGINGFACE_TOKEN:
+    try:
+        login(token=HUGGINGFACE_TOKEN)
+        logger.info("Successfully logged in to Hugging Face Hub.")
+    except Exception as e:
+        logger.error(f"Failed to login to Hugging Face Hub: {e}")
 
 class ModelManager:
     _instance = None
+
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(ModelManager, cls).__new__(cls)
             cls._instance._initialized = False
         return cls._instance
+
     def __init__(self):
         if not self._initialized:
             self.tokenizer = None
             self.model = None
+            self.text_pipeline = None  # Renamed for clarity
             self.whisper_model = None
             self._initialized = True
             self.last_used = time.time()
+            self.llm_loading = False
+            self.whisper_loading = False
+
+    @spaces.GPU(duration=120)  # Increased duration for potentially long loads
     def initialize_llm(self):
+        """Initialize LLM model with standard transformers"""
+        if self.llm_loading:
+            logger.info("LLM initialization already in progress.")
+            return True  # Assume it will succeed or fail elsewhere
+        if self.tokenizer and self.model and self.text_pipeline:
+            logger.info("LLM already initialized.")
+            self.last_used = time.time()
+            return True
+
+        self.llm_loading = True
         try:
+            # Use small model for ZeroGPU compatibility
             MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+            logger.info("Loading LLM tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_NAME,
+                token=HUGGINGFACE_TOKEN,
+                use_fast=True
+            )
+
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Basic memory settings for ZeroGPU
+            logger.info("Loading LLM model...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                token=HUGGINGFACE_TOKEN,
+                device_map="auto",
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                # Optimizations for ZeroGPU
+                # max_memory={0: "4GB"},  # Removed for better auto handling initially
+                offload_folder="offload",
+                offload_state_dict=True
             )
+
+            # Create text generation pipeline
+            logger.info("Creating LLM text generation pipeline...")
+            self.text_pipeline = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                max_length=1024  # Default max length
             )
+
+            logger.info("LLM initialized successfully")
             self.last_used = time.time()
+            self.llm_loading = False
             return True
+
         except Exception as e:
             logger.error(f"Error initializing LLM: {str(e)}")
+            logger.error(traceback.format_exc())  # Log full traceback
+            # Reset partially loaded components
+            self.tokenizer = None
+            self.model = None
+            self.text_pipeline = None
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+            self.llm_loading = False
+            raise  # Re-raise the exception to signal failure
 
+    @spaces.GPU(duration=120)  # Increased duration
     def initialize_whisper(self):
+        """Initialize Whisper model for audio transcription"""
+        if self.whisper_loading:
+            logger.info("Whisper initialization already in progress.")
+            return True
+        if self.whisper_model:
+            logger.info("Whisper already initialized.")
+            self.last_used = time.time()
+            return True
+
+        self.whisper_loading = True
         try:
             logger.info("Loading Whisper model...")
+            # Using tiny model for efficiency but can be changed based on needs
+            # Specify weights_only=True to address the FutureWarning
+            # Note: Whisper's load_model might not directly support weights_only yet.
+            # If it errors, remove the weights_only=True. The warning is mainly informative.
+            # Let's attempt without weights_only first as whisper might handle it internally
             self.whisper_model = whisper.load_model(
+                "tiny",  # Consider "base" for better accuracy if "tiny" struggles
                 device="cuda" if torch.cuda.is_available() else "cpu",
+                download_root="/tmp/whisper"  # Use persistent storage if available/needed
             )
             logger.info("Whisper model initialized successfully")
             self.last_used = time.time()
+            self.whisper_loading = False
             return True
         except Exception as e:
             logger.error(f"Error initializing Whisper: {str(e)}")
+            logger.error(traceback.format_exc())
+            self.whisper_model = None
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+            self.whisper_loading = False
             raise
 
     def check_llm_initialized(self):
         """Check if LLM is initialized and initialize if needed"""
+        if self.tokenizer is None or self.model is None or self.text_pipeline is None:
             logger.info("LLM not initialized, initializing...")
+            if not self.llm_loading:  # Prevent re-entry if already loading
+                self.initialize_llm()
+            else:
+                logger.info("LLM initialization is already in progress by another request.")
+                # Optional: Wait a bit for the other process to finish
+                time.sleep(5)
+                if self.tokenizer is None or self.model is None or self.text_pipeline is None:
+                    raise RuntimeError("LLM initialization timed out or failed.")
         self.last_used = time.time()
+
     def check_whisper_initialized(self):
         """Check if Whisper model is initialized and initialize if needed"""
         if self.whisper_model is None:
             logger.info("Whisper model not initialized, initializing...")
+            if not self.whisper_loading:  # Prevent re-entry
+                self.initialize_whisper()
+            else:
+                logger.info("Whisper initialization is already in progress by another request.")
+                time.sleep(5)
+                if self.whisper_model is None:
+                    raise RuntimeError("Whisper initialization timed out or failed.")
         self.last_used = time.time()
+
     def reset_models(self, force=False):
         """Reset models to free memory if they haven't been used recently"""
         current_time = time.time()
+        # Only reset if forced or models haven't been used for 10 minutes (600 seconds)
+        if force or (current_time - self.last_used > 600):
             try:
                 logger.info("Resetting models to free memory...")
+
+                # Check and delete attributes safely
+                if hasattr(self, 'model') and self.model is not None:
                     del self.model
+                    self.model = None
+                    logger.info("LLM model deleted.")
+
+                if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                     del self.tokenizer
+                    self.tokenizer = None
+                    logger.info("LLM tokenizer deleted.")
+
+                if hasattr(self, 'text_pipeline') and self.text_pipeline is not None:
+                    del self.text_pipeline
+                    self.text_pipeline = None
+                    logger.info("LLM pipeline deleted.")
+
+                if hasattr(self, 'whisper_model') and self.whisper_model is not None:
                     del self.whisper_model
+                    self.whisper_model = None
+                    logger.info("Whisper model deleted.")
+
+                # Explicitly clear CUDA cache and collect garbage
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
+                    # torch.cuda.synchronize()  # May not be needed and can slow down
+                    logger.info("CUDA cache cleared.")
+
                 gc.collect()
+                logger.info("Garbage collected. Models reset successfully.")
+                self._initialized = False  # Mark as uninitialized so they reload on next use
+
             except Exception as e:
                 logger.error(f"Error resetting models: {str(e)}")
+                logger.error(traceback.format_exc())
 
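A minimal usage sketch of the lazy-loading singleton above (names as defined in ModelManager; the audio path is a placeholder):

    # Every construction returns the same instance, so all request handlers
    # share one set of loaded models.
    manager = ModelManager()
    assert manager is ModelManager()

    # Models are loaded on first use, not at import time.
    manager.check_whisper_initialized()
    text = manager.whisper_model.transcribe("clip.mp3")["text"]  # placeholder path

    # Free GPU memory: no-op unless idle for >600s, or immediately with force=True.
    manager.reset_models(force=True)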
+# Create global model manager instance
 model_manager = ModelManager()
 
+@lru_cache(maxsize=16)  # Reduced cache size slightly
 def download_social_media_video(url):
+    """Download audio from a social media video URL."""
+    temp_dir = tempfile.mkdtemp()
+    output_template = os.path.join(temp_dir, '%(id)s.%(ext)s')
+
     ydl_opts = {
         'format': 'bestaudio/best',
         'postprocessors': [{
             'key': 'FFmpegExtractAudio',
             'preferredcodec': 'mp3',
+            'preferredquality': '192',  # Standard quality
         }],
+        'outtmpl': output_template,
+        'quiet': True,
+        'no_warnings': True,
+        'nocheckcertificate': True,  # Sometimes needed for tricky sites
+        'retries': 3,  # Add retries
+        'socket_timeout': 15,  # Timeout
     }
     try:
+        logger.info(f"Attempting to download audio from: {url}")
         with yt_dlp.YoutubeDL(ydl_opts) as ydl:
             info_dict = ydl.extract_info(url, download=True)
+            # Construct the expected final filename after postprocessing
+            audio_file = os.path.join(temp_dir, f"{info_dict['id']}.mp3")
+            if not os.path.exists(audio_file):
+                # Fallback if filename doesn't match exactly (e.g., webm -> mp3)
+                found_files = [f for f in os.listdir(temp_dir) if f.endswith('.mp3')]
+                if found_files:
+                    audio_file = os.path.join(temp_dir, found_files[0])
+                else:
+                    raise FileNotFoundError(f"Could not find downloaded MP3 in {temp_dir}")
+
+            logger.info(f"Audio downloaded successfully: {audio_file}")
+            # Read the file content to return, as the temp dir might be cleaned up
+            with open(audio_file, 'rb') as f:
+                audio_content = f.read()
+
+            # Clean up the temporary directory and file
+            try:
+                os.remove(audio_file)
+                os.rmdir(temp_dir)
+            except OSError as e:
+                logger.warning(f"Could not completely clean up temp download files: {e}")
+
+            # Save the content to a new temporary file that Gradio can handle
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_output_file:
+                temp_output_file.write(audio_content)
+                final_path = temp_output_file.name
+            logger.info(f"Audio saved to temporary file: {final_path}")
+            return final_path
+
+    except yt_dlp.utils.DownloadError as e:
+        logger.error(f"yt-dlp download error for {url}: {str(e)}")
+        # Clean up temp dir on error
+        try:
+            if os.path.exists(temp_dir):
+                import shutil
+                shutil.rmtree(temp_dir)
+        except Exception as cleanup_e:
+            logger.warning(f"Error during cleanup after download failure: {cleanup_e}")
+        return None  # Return None to indicate failure
     except Exception as e:
+        logger.error(f"Unexpected error downloading video from {url}: {str(e)}")
+        logger.error(traceback.format_exc())
+        # Clean up temp dir on error
+        try:
+            if os.path.exists(temp_dir):
+                import shutil
+                shutil.rmtree(temp_dir)
+        except Exception as cleanup_e:
+            logger.warning(f"Error during cleanup after download failure: {cleanup_e}")
+        return None  # Return None
 
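A sketch of how the downloader above might be exercised; the commented line mirrors ydl_opts as a yt-dlp CLI call (flag names from yt-dlp; the URL is a placeholder):

    # Roughly equivalent command line:
    #   yt-dlp -f "bestaudio/best" -x --audio-format mp3 --audio-quality 192K \
    #          --retries 3 --socket-timeout 15 -o "%(id)s.%(ext)s" "https://example.com/post"
    audio_path = download_social_media_video("https://example.com/post")  # placeholder URL
    if audio_path is None:
        print("download failed")  # the function returns None instead of raising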
+def convert_video_to_audio(video_file_path):
     """Convert a video file to audio using ffmpeg directly."""
     try:
+        # Create a temporary file path for the output MP3
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
+            output_file_path = temp_file.name
+
+        logger.info(f"Converting video '{video_file_path}' to audio '{output_file_path}'")
+
+        # Use ffmpeg directly via subprocess
         command = [
+            "ffmpeg",
+            "-i", video_file_path,
+            "-vn",  # No video
+            "-acodec", "libmp3lame",  # Specify MP3 codec
+            "-ab", "192k",  # Audio bitrate
+            "-ar", "44100",  # Audio sample rate
+            "-ac", "2",  # Stereo audio
+            output_file_path,
+            "-y",  # Overwrite output file if it exists
+            "-loglevel", "error"  # Suppress verbose ffmpeg output
         ]
+
+        process = subprocess.run(command, check=True, capture_output=True, text=True)
+        logger.info(f"ffmpeg conversion successful for {video_file_path}.")
+        logger.debug(f"ffmpeg stdout: {process.stdout}")
+        logger.debug(f"ffmpeg stderr: {process.stderr}")
+
+        # Verify output file exists and has size
+        if not os.path.exists(output_file_path) or os.path.getsize(output_file_path) == 0:
+            raise RuntimeError(f"ffmpeg conversion failed: Output file '{output_file_path}' not created or is empty.")
+
+        logger.info(f"Video converted to audio: {output_file_path}")
+        return output_file_path
+    except subprocess.CalledProcessError as e:
+        logger.error(f"ffmpeg command failed with exit code {e.returncode}")
+        logger.error(f"ffmpeg stderr: {e.stderr}")
+        logger.error(f"ffmpeg stdout: {e.stdout}")
+        # Clean up potentially empty output file
+        if os.path.exists(output_file_path):
+            os.remove(output_file_path)
+        raise RuntimeError(f"ffmpeg conversion failed: {e.stderr}") from e
     except Exception as e:
+        logger.error(f"Error converting video '{video_file_path}': {str(e)}")
+        logger.error(traceback.format_exc())
+        # Clean up potentially created output file
+        if 'output_file_path' in locals() and os.path.exists(output_file_path):
+            os.remove(output_file_path)
+        raise  # Re-raise the exception
 
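The subprocess call above corresponds to this ffmpeg invocation, in exactly the argument order built in `command` (paths are placeholders):

    # ffmpeg -i input.mp4 -vn -acodec libmp3lame -ab 192k -ar 44100 -ac 2 \
    #        output.mp3 -y -loglevel error
    mp3_path = convert_video_to_audio("input.mp4")  # raises on conversion failure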
+def preprocess_audio(input_audio_path):
+    """Preprocess the audio file (e.g., normalize volume)."""
     try:
+        logger.info(f"Preprocessing audio file: {input_audio_path}")
+        audio = AudioSegment.from_file(input_audio_path)
+
+        # Apply normalization (optional, adjust target dBFS as needed)
+        # Target loudness: -20 dBFS. Adjust gain based on current loudness.
+        # change_in_dBFS = -20.0 - audio.dBFS
+        # audio = audio.apply_gain(change_in_dBFS)
+
+        # Export to a new temporary file
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
+            output_path = temp_file.name
+            audio.export(output_path, format="mp3")
+
+        logger.info(f"Audio preprocessed and saved to: {output_path}")
+        return output_path
     except Exception as e:
+        logger.error(f"Error preprocessing audio '{input_audio_path}': {str(e)}")
+        logger.error(traceback.format_exc())
+        # Return original path if preprocessing fails? Or raise error?
+        # Let's raise the error to signal failure clearly.
         raise
 
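If the commented-out normalization above is ever enabled, the pydub gain math would look like this (a sketch; the -20.0 dBFS target is the value assumed in the comments, and pydub.effects.normalize is an off-the-shelf alternative):

    from pydub import AudioSegment

    audio = AudioSegment.from_file("input.mp3")         # placeholder path
    target_dbfs = -20.0                                 # assumed target loudness
    audio = audio.apply_gain(target_dbfs - audio.dBFS)  # boost or cut to the target
    audio.export("normalized.mp3", format="mp3")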
+@spaces.GPU(duration=300)  # Allow more time for transcription
+def transcribe_audio_or_video(file_input):
+    """Transcribe an audio or video file (local path or Gradio File object)."""
+    audio_file_to_transcribe = None
+    original_input_path = None
+    temp_files_to_clean = []
+
     try:
         model_manager.check_whisper_initialized()
+
+        if file_input is None:
+            logger.info("No file input provided for transcription.")
+            return ""  # Return empty string for None input
+
+        # Determine input type and get file path
+        if isinstance(file_input, str):  # Input is a path
+            original_input_path = file_input
+            logger.info(f"Processing path input: {original_input_path}")
+            if not os.path.exists(original_input_path):
+                logger.error(f"Input file path does not exist: {original_input_path}")
+                raise FileNotFoundError(f"Input file not found: {original_input_path}")
+            input_path = original_input_path
+        elif hasattr(file_input, 'name'):  # Input is a Gradio File object
+            original_input_path = file_input.name
+            logger.info(f"Processing Gradio file input: {original_input_path}")
+            input_path = original_input_path  # Gradio usually provides a temp path
         else:
+            logger.error(f"Unsupported input type for transcription: {type(file_input)}")
+            raise TypeError("Invalid input type for transcription. Expected file path or Gradio File object.")
+
+        file_extension = os.path.splitext(input_path)[1].lower()
+
+        # Check if it's a video file that needs conversion
+        if file_extension in ['.mp4', '.avi', '.mov', '.mkv', '.webm']:
+            logger.info(f"Detected video file ({file_extension}), converting to audio...")
+            converted_audio_path = convert_video_to_audio(input_path)
+            temp_files_to_clean.append(converted_audio_path)
+            audio_file_to_process = converted_audio_path
+        elif file_extension in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
+            logger.info(f"Detected audio file ({file_extension}).")
+            audio_file_to_process = input_path
+        else:
+            logger.error(f"Unsupported file extension for transcription: {file_extension}")
+            raise ValueError(f"Unsupported file type: {file_extension}")
+
+        # Preprocess the audio (optional, could be skipped if causing issues)
         try:
+            preprocessed_audio_path = preprocess_audio(audio_file_to_process)
+            # If preprocessing creates a new file different from the input, add it to cleanup
+            if preprocessed_audio_path != audio_file_to_process:
+                temp_files_to_clean.append(preprocessed_audio_path)
+            audio_file_to_transcribe = preprocessed_audio_path
+        except Exception as preprocess_err:
+            logger.warning(f"Audio preprocessing failed: {preprocess_err}. Using original/converted audio.")
+            audio_file_to_transcribe = audio_file_to_process  # Fallback
+
+        logger.info(f"Transcribing audio file: {audio_file_to_transcribe}")
+        if not os.path.exists(audio_file_to_transcribe):
+            raise FileNotFoundError(f"Audio file to transcribe not found: {audio_file_to_transcribe}")
+
+        # Perform transcription
+        with torch.inference_mode():  # Ensure inference mode for efficiency
+            # Use fp16 if available on CUDA
+            use_fp16 = torch.cuda.is_available()
+            result = model_manager.whisper_model.transcribe(
+                audio_file_to_transcribe,
+                fp16=use_fp16
+            )
+            if not result:
+                raise RuntimeError("Transcription failed to produce results")
+
+            transcription = result.get("text", "Error: Transcription result empty")
+            # Limit transcription length shown in logs
+            log_transcription = (transcription[:100] + '...') if len(transcription) > 100 else transcription
+            logger.info(f"Transcription completed: {log_transcription}")
+
         return transcription
+
+    except FileNotFoundError as e:
+        logger.error(f"File not found error during transcription: {e}")
+        return f"Error: Input file not found ({e})"
+    except ValueError as e:
+        logger.error(f"Value error during transcription: {e}")
+        return f"Error: Unsupported file type ({e})"
+    except TypeError as e:
+        logger.error(f"Type error during transcription setup: {e}")
+        return f"Error: Invalid input provided ({e})"
+    except RuntimeError as e:
+        logger.error(f"Runtime error during transcription: {e}")
+        logger.error(traceback.format_exc())
+        return f"Error during processing: {e}"
     except Exception as e:
+        logger.error(f"Unexpected error during transcription: {str(e)}")
+        logger.error(traceback.format_exc())
+        return "Error processing the file: An unexpected error occurred."
+
+    finally:
+        # Clean up all temporary files created during the process
+        for temp_file in temp_files_to_clean:
+            try:
+                if os.path.exists(temp_file):
+                    os.remove(temp_file)
+                    logger.info(f"Cleaned up temporary file: {temp_file}")
+            except Exception as e:
+                logger.warning(f"Could not remove temporary file {temp_file}: {str(e)}")
+        # Optionally reset models if idle (might be too aggressive here)
+        # model_manager.reset_models()
 
+
@lru_cache(maxsize=16)
|
499 |
def read_document(document_path):
|
500 |
+
"""Read the content of a document (PDF, DOCX, XLSX, CSV)."""
|
501 |
try:
|
502 |
+
logger.info(f"Reading document: {document_path}")
|
503 |
+
if not os.path.exists(document_path):
|
504 |
+
raise FileNotFoundError(f"Document not found: {document_path}")
|
505 |
+
|
506 |
+
file_extension = os.path.splitext(document_path)[1].lower()
|
507 |
+
|
508 |
+
if file_extension == ".pdf":
|
509 |
doc = fitz.open(document_path)
|
510 |
+
text = "\n".join([page.get_text() for page in doc])
|
511 |
+
doc.close()
|
512 |
+
return text
|
513 |
+
elif file_extension == ".docx":
|
514 |
doc = docx.Document(document_path)
|
515 |
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
516 |
+
elif file_extension in (".xlsx", ".xls"):
|
517 |
+
# Read all sheets and combine
|
518 |
+
xls = pd.ExcelFile(document_path)
|
519 |
+
text = ""
|
520 |
+
for sheet_name in xls.sheet_names:
|
521 |
+
df = pd.read_excel(xls, sheet_name=sheet_name)
|
522 |
+
text += f"--- Sheet: {sheet_name} ---\n{df.to_string()}\n\n"
|
523 |
+
return text.strip()
|
524 |
+
elif file_extension == ".csv":
|
525 |
+
# Try detecting separator
|
526 |
+
try:
|
527 |
+
df = pd.read_csv(document_path)
|
528 |
+
except pd.errors.ParserError:
|
529 |
+
logger.warning(f"Could not parse CSV {document_path} with default comma separator, trying semicolon.")
|
530 |
+
df = pd.read_csv(document_path, sep=';')
|
531 |
+
return df.to_string()
|
532 |
else:
|
533 |
+
logger.warning(f"Unsupported document type: {file_extension}")
|
534 |
return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
|
535 |
+
except FileNotFoundError as e:
|
536 |
+
logger.error(f"Error reading document: {e}")
|
537 |
+
return f"Error: Document file not found at {document_path}"
|
538 |
except Exception as e:
|
539 |
+
logger.error(f"Error reading document {document_path}: {str(e)}")
|
540 |
+
logger.error(traceback.format_exc())
|
541 |
return f"Error reading document: {str(e)}"
|
542 |
|
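The comma-then-semicolon retry above covers the two common CSV dialects; as a possible alternative, pandas can sniff the delimiter itself when given sep=None with the python engine (slower, but handles more dialects):

    import pandas as pd

    # sep=None + engine="python" infers the delimiter via csv.Sniffer
    df = pd.read_csv("data.csv", sep=None, engine="python")  # placeholder path
    print(df.to_string())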
+@lru_cache(maxsize=16)
 def read_url(url):
+    """Read the main textual content of a URL."""
+    if not url or not url.strip().startswith('http'):
+        logger.info(f"Invalid or empty URL provided: '{url}'")
+        return ""  # Return empty for invalid or empty URLs
+
     try:
+        logger.info(f"Reading URL: {url}")
         headers = {
             'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
         }
+        # Increased timeout
+        response = requests.get(url, headers=headers, timeout=20, allow_redirects=True)
+        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
+
+        # Check content type - proceed only if likely HTML/text
+        content_type = response.headers.get('content-type', '').lower()
+        if not ('html' in content_type or 'text' in content_type):
+            logger.warning(f"URL {url} has non-text content type: {content_type}. Skipping.")
+            return f"Error: URL content type ({content_type}) is not text/html."
+
         soup = BeautifulSoup(response.content, 'html.parser')
+
+        # Remove non-content elements like scripts, styles, nav, footers etc.
+        for element in soup(["script", "style", "meta", "noscript", "iframe", "header", "footer", "nav", "aside", "form", "button"]):
             element.extract()
+
+        # Attempt to find main content area (common tags/attributes)
+        main_content = (
+            soup.find("main") or
+            soup.find("article") or
+            soup.find("div", class_=["content", "main", "post-content", "entry-content", "article-body"]) or
+            soup.find("div", id=["content", "main", "article"])
+        )
+
         if main_content:
             text = main_content.get_text(separator='\n', strip=True)
         else:
+            # Fallback to body if no specific main content found
+            body = soup.find("body")
+            if body:
+                text = body.get_text(separator='\n', strip=True)
+            else:  # Very basic fallback
+                text = soup.get_text(separator='\n', strip=True)
+
+        # Clean up whitespace: replace multiple newlines/spaces with single ones
+        text = '\n'.join([line.strip() for line in text.split('\n') if line.strip()])
+        text = ' '.join(text.split())  # Consolidate spaces within lines
+
+        if not text:
+            logger.warning(f"Could not extract meaningful text from URL: {url}")
+            return "Error: Could not extract text content from URL."
+
+        # Limit content size to avoid overwhelming the LLM
+        max_chars = 15000
+        if len(text) > max_chars:
+            logger.info(f"URL content truncated to {max_chars} characters.")
+            text = text[:max_chars] + "... [content truncated]"
+
+        return text
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Error fetching URL {url}: {str(e)}")
+        return f"Error reading URL: Could not fetch content ({e})"
     except Exception as e:
+        logger.error(f"Error parsing URL {url}: {str(e)}")
+        logger.error(traceback.format_exc())
+        return f"Error reading URL: Could not parse content ({e})"
 
+def process_social_media_url(url):
+    """Process a social media URL, attempting to get text and transcribe video/audio."""
+    if not url or not url.strip().startswith('http'):
+        logger.info(f"Invalid or empty social media URL: '{url}'")
         return None
+
+    logger.info(f"Processing social media URL: {url}")
+    text_content = None
+    video_transcription = None
+    error_occurred = False
+
+    # 1. Try extracting text content using read_url (might work for some platforms/posts)
     try:
         text_content = read_url(url)
+        if text_content and text_content.startswith("Error:"):
+            logger.warning(f"Failed to read text content from social URL {url}: {text_content}")
+            text_content = None  # Reset if it was an error message
+    except Exception as e:
+        logger.error(f"Error reading text content from social URL {url}: {e}")
+        error_occurred = True
 
+    # 2. Try downloading and transcribing potential video/audio content
+    downloaded_audio_path = None
+    try:
+        downloaded_audio_path = download_social_media_video(url)
+        if downloaded_audio_path:
+            logger.info(f"Audio downloaded from {url}, proceeding to transcription.")
+            video_transcription = transcribe_audio_or_video(downloaded_audio_path)
+            if video_transcription and video_transcription.startswith("Error"):
+                logger.warning(f"Transcription failed for audio from {url}: {video_transcription}")
+                video_transcription = None  # Reset if it was an error
+        else:
+            logger.info(f"No downloadable audio/video found or download failed for URL: {url}")
+    except Exception as e:
+        logger.error(f"Error processing video content from social URL {url}: {e}")
+        logger.error(traceback.format_exc())
+        error_occurred = True
+    finally:
+        # Clean up downloaded file if it exists
+        if downloaded_audio_path and os.path.exists(downloaded_audio_path):
+            try:
+                os.remove(downloaded_audio_path)
+                logger.info(f"Cleaned up downloaded audio: {downloaded_audio_path}")
+            except Exception as e:
+                logger.warning(f"Failed to cleanup downloaded audio {downloaded_audio_path}: {e}")
+
+    # Return results only if some content was found or no critical error occurred
+    if text_content or video_transcription or not error_occurred:
         return {
+            "text": text_content or "",  # Ensure string type
+            "video": video_transcription or ""  # Ensure string type
         }
+    else:
+        logger.error(f"Failed to process social media URL {url} completely.")
+        return None  # Indicate failure
 
669 |
+
@spaces.GPU(duration=300) # Allow more time for generation
|
670 |
def generate_news(instructions, facts, size, tone, *args):
|
671 |
+
"""Generate a news article based on provided data using an LLM."""
|
672 |
+
request_start_time = time.time()
|
673 |
+
logger.info("Received request to generate news.")
|
674 |
try:
|
675 |
+
# Ensure size is integer
|
676 |
+
try:
|
677 |
+
size = int(size) if size else 250 # Default size if None/empty
|
678 |
+
except ValueError:
|
679 |
+
logger.warning(f"Invalid size value '{size}', defaulting to 250.")
|
680 |
size = 250
|
681 |
+
|
682 |
+
# Check if models are initialized, load if necessary
|
683 |
+
model_manager.check_llm_initialized() # LLM is essential
|
684 |
+
# Whisper might be needed later, check/load if audio sources exist
|
685 |
+
|
686 |
+
# --- Argument Parsing ---
|
687 |
+
# The order *must* match the order components are added to inputs_list in create_demo
|
688 |
+
# Fixed inputs: instructions, facts, size, tone (already passed directly)
|
689 |
+
# Dynamic inputs from *args:
|
690 |
+
# Expected order in *args based on create_demo:
|
691 |
+
# 5 Documents, 15 Audio-related, 5 URLs, 9 Social-related
|
692 |
+
num_docs = 5
|
693 |
+
num_audio_sources = 5
|
694 |
+
num_audio_inputs_per_source = 3
|
695 |
+
num_urls = 5
|
696 |
+
num_social_sources = 3
|
697 |
+
num_social_inputs_per_source = 3
|
698 |
+
|
699 |
+
total_expected_args = num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls + (num_social_sources * num_social_inputs_per_source)
|
700 |
+
|
701 |
+
args_list = list(args)
|
702 |
+
# Pad args_list with None if fewer arguments were received than expected
|
703 |
+
args_list.extend([None] * (total_expected_args - len(args_list)))
|
704 |
+
|
705 |
+
# Slice arguments based on the expected order
|
706 |
+
doc_files = args_list[0:num_docs]
|
707 |
+
audio_inputs_flat = args_list[num_docs : num_docs + (num_audio_sources * num_audio_inputs_per_source)]
|
708 |
+
url_inputs = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) : num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls]
|
709 |
+
social_inputs_flat = args_list[num_docs + (num_audio_sources * num_audio_inputs_per_source) + num_urls : total_expected_args]
|
710 |
+
|
711 |
knowledge_base = {
|
712 |
+
"instructions": instructions or "No specific instructions provided.",
|
713 |
+
"facts": facts or "No specific facts provided.",
|
714 |
"document_content": [],
|
715 |
"audio_data": [],
|
716 |
"url_content": [],
|
717 |
"social_content": []
|
718 |
}
|
719 |
+
raw_transcriptions = "" # Initialize transcription log
|
720 |
|
721 |
+
# --- Process Inputs ---
|
722 |
+
logger.info("Processing document inputs...")
|
723 |
+
for i, doc_file in enumerate(doc_files):
|
724 |
+
if doc_file and hasattr(doc_file, 'name'):
|
725 |
+
try:
|
726 |
+
content = read_document(doc_file.name) # doc_file.name is the temp path
|
727 |
+
if content and not content.startswith("Error"):
|
728 |
+
# Truncate long documents for the knowledge base summary
|
729 |
+
doc_excerpt = (content[:1000] + "... [document truncated]") if len(content) > 1000 else content
|
730 |
+
knowledge_base["document_content"].append(f"[Document {i+1} Source: {os.path.basename(doc_file.name)}]\n{doc_excerpt}")
|
731 |
+
else:
|
732 |
+
logger.warning(f"Skipping document {i+1} due to read error or empty content: {content}")
|
733 |
+
except Exception as e:
|
734 |
+
logger.error(f"Failed to process document {i+1} ({doc_file.name}): {e}")
|
735 |
+
# No cleanup needed here, Gradio handles temp file uploads
|
736 |
+
|
737 |
+
logger.info("Processing URL inputs...")
|
738 |
+
for i, url in enumerate(url_inputs):
|
739 |
+
if url and isinstance(url, str) and url.strip().startswith('http'):
|
740 |
+
try:
|
741 |
+
content = read_url(url)
|
742 |
+
if content and not content.startswith("Error"):
|
743 |
+
# Content is already truncated in read_url if needed
|
744 |
+
knowledge_base["url_content"].append(f"[URL {i+1} Source: {url}]\n{content}")
|
745 |
+
else:
|
746 |
+
logger.warning(f"Skipping URL {i+1} ({url}) due to read error or empty content: {content}")
|
747 |
+
except Exception as e:
|
748 |
+
logger.error(f"Failed to process URL {i+1} ({url}): {e}")
|
749 |
|
750 |
+
logger.info("Processing audio/video inputs...")
|
751 |
+
has_audio_source = False
|
752 |
+
for i in range(num_audio_sources):
|
753 |
+
start_idx = i * num_audio_inputs_per_source
|
754 |
+
audio_file = audio_inputs_flat[start_idx]
|
755 |
+
name = audio_inputs_flat[start_idx + 1] or f"Source {i+1}"
|
756 |
+
position = audio_inputs_flat[start_idx + 2] or "N/A"
|
757 |
+
|
758 |
+
if audio_file and hasattr(audio_file, 'name'):
|
759 |
+
# Store info for transcription later
|
760 |
+
knowledge_base["audio_data"].append({
|
761 |
+
"file_path": audio_file.name, # Use the temp path
|
762 |
+
"name": name,
|
763 |
+
"position": position,
|
764 |
+
"original_filename": os.path.basename(audio_file.name) # Keep original for logs
|
765 |
+
})
|
766 |
+
has_audio_source = True
|
767 |
+
logger.info(f"Added audio source {i+1}: {name} ({position}) - File: {knowledge_base['audio_data'][-1]['original_filename']}")
|
768 |
+
|
769 |
+
logger.info("Processing social media inputs...")
|
770 |
+
has_social_source = False
|
771 |
+
for i in range(num_social_sources):
|
772 |
+
start_idx = i * num_social_inputs_per_source
|
773 |
+
social_url = social_inputs_flat[start_idx]
|
774 |
+
social_name = social_inputs_flat[start_idx + 1] or f"Social Source {i+1}"
|
775 |
+
social_context = social_inputs_flat[start_idx + 2] or "N/A"
|
776 |
+
|
777 |
+
if social_url and isinstance(social_url, str) and social_url.strip().startswith('http'):
|
778 |
+
try:
|
779 |
+
logger.info(f"Processing social media URL {i+1}: {social_url}")
|
780 |
+
social_data = process_social_media_url(social_url)
|
781 |
+
if social_data:
|
782 |
+
knowledge_base["social_content"].append({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
783 |
"url": social_url,
|
784 |
+
"name": social_name,
|
785 |
+
"context": social_context,
|
786 |
+
"text": social_data.get("text", ""),
|
787 |
+
"video_transcription": social_data.get("video", "") # Store potential transcription
|
788 |
})
|
789 |
+
has_social_source = True
|
790 |
+
logger.info(f"Added social source {i+1}: {social_name} ({social_context}) from {social_url}")
|
791 |
+
else:
|
792 |
+
logger.warning(f"Could not retrieve any content for social URL {i+1}: {social_url}")
|
793 |
+
except Exception as e:
|
794 |
+
logger.error(f"Failed to process social URL {i+1} ({social_url}): {e}")
|
795 |
|
|
|
|
|
796 |
|
797 |
+
# --- Transcribe Audio/Video ---
|
798 |
+
# Only initialize Whisper if needed
|
799 |
+
transcriptions_for_prompt = ""
|
800 |
+
if has_audio_source or any(sc.get("video_transcription") == "[NEEDS_TRANSCRIPTION]" for sc in knowledge_base["social_content"]): # Check if transcription actually needed
|
801 |
+
logger.info("Audio sources detected, ensuring Whisper model is ready...")
|
802 |
+
try:
|
803 |
+
model_manager.check_whisper_initialized()
|
804 |
+
except Exception as whisper_init_err:
|
805 |
+
logger.error(f"FATAL: Whisper model initialization failed: {whisper_init_err}. Cannot transcribe.")
|
806 |
+
# Add error message to raw transcriptions and continue without transcriptions
|
807 |
+
raw_transcriptions += f"[ERROR] Whisper model failed to load. Audio sources could not be transcribed: {whisper_init_err}\n\n"
|
808 |
+
# Optionally return an error message immediately?
|
809 |
+
# return f"Error: Could not initialize transcription model. {whisper_init_err}", raw_transcriptions
|
810 |
+
|
811 |
+
if model_manager.whisper_model: # Proceed only if whisper loaded successfully
|
812 |
+
logger.info("Transcribing collected audio sources...")
|
813 |
+
for idx, data in enumerate(knowledge_base["audio_data"]):
|
814 |
+
try:
|
815 |
+
logger.info(f"Transcribing audio source {idx+1}: {data['original_filename']} ({data['name']}, {data['position']})")
|
816 |
+
transcription = transcribe_audio_or_video(data["file_path"])
|
817 |
+
if transcription and not transcription.startswith("Error"):
|
818 |
+
quote = f'"{transcription}" - {data["name"]}, {data["position"]}'
|
819 |
+
transcriptions_for_prompt += f"{quote}\n\n"
|
820 |
+
raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n"{transcription}"\n\n'
|
821 |
+
else:
|
822 |
+
logger.warning(f"Transcription failed or returned error for audio source {idx+1}: {transcription}")
|
823 |
+
raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {transcription}]\n\n'
|
824 |
+
except Exception as e:
|
825 |
+
logger.error(f"Error during transcription for audio source {idx+1} ({data['original_filename']}): {e}")
|
826 |
+
logger.error(traceback.format_exc())
|
827 |
+
raw_transcriptions += f'[Audio/Video {idx + 1}: {data["original_filename"]} ({data["name"]}, {data["position"]})]\n[Error during transcription: {e}]\n\n'
|
828 |
+
# Gradio handles cleanup of the uploaded temp file audio_file.name
|
829 |
|
830 |
+
logger.info("Adding social media content to prompt data...")
|
831 |
for idx, data in enumerate(knowledge_base["social_content"]):
|
832 |
+
source_id = f'[Social Media {idx+1}: {data["url"]} ({data["name"]}, {data["context"]})]'
|
833 |
+
has_content = False
|
834 |
+
if data["text"] and not data["text"].startswith("Error"):
|
835 |
+
# Truncate long text for the prompt, but keep full in knowledge base maybe?
|
836 |
+
text_excerpt = (data["text"][:500] + "...[text truncated]") if len(data["text"]) > 500 else data["text"]
|
837 |
+
social_text_prompt = f'{source_id} - Text Content:\n"{text_excerpt}"\n\n'
|
838 |
+
transcriptions_for_prompt += social_text_prompt # Add text content as if it were a quote/source
|
839 |
+
raw_transcriptions += f"{source_id}\nText Content:\n{data['text']}\n\n" # Log full text
|
840 |
+
has_content = True
|
841 |
+
if data["video_transcription"] and not data["video_transcription"].startswith("Error"):
|
842 |
+
social_video_prompt = f'{source_id} - Video Transcription:\n"{data["video_transcription"]}"\n\n'
|
843 |
+
transcriptions_for_prompt += social_video_prompt
|
844 |
+
raw_transcriptions += f"{source_id}\nVideo Transcription:\n{data['video_transcription']}\n\n"
|
845 |
+
has_content = True
|
846 |
+
if not has_content:
|
847 |
+
raw_transcriptions += f"{source_id}\n[No usable text or video transcription found]\n\n"
|
848 |
+
|
849 |
+
|
850 |
+
        # --- Prepare Final Prompt ---
        # Combine document, URL, and transcription material into prompt sections
        document_summary = "\n\n".join(knowledge_base["document_content"]) if knowledge_base["document_content"] else "No document content provided."
        url_summary = "\n\n".join(knowledge_base["url_content"]) if knowledge_base["url_content"] else "No URL content provided."
        transcription_summary = transcriptions_for_prompt if transcriptions_for_prompt else "No usable transcriptions available."

        # Construct the prompt for the LLM
        prompt = f"""<s>[INST] You are a professional news writer. Your task is to synthesize information from various sources into a coherent news article.

Primary Instructions: {knowledge_base["instructions"]}
Key Facts to Include: {knowledge_base["facts"]}

Supporting Information:

Document Content Summary:
{document_summary}

Web Content Summary (from URLs):
{url_summary}

Transcribed Quotes/Content (use these directly or indirectly):
{transcription_summary}

Article Requirements:
- Title: Create a concise and informative title for the article.
- Hook: Write a compelling hook sentence of roughly 15 words that complements the title.
- Body: Write the main news article body, aiming for approximately {size} words.
- Tone: Adopt a {tone} tone throughout the article.
- 5 Ws: Ensure the first paragraph addresses the core questions (Who, What, When, Where, Why).
- Quotes: Incorporate relevant information from the 'Transcribed Quotes/Content' section. Use quotes where appropriate, but synthesize information rather than just listing quotes. Use quotation marks (" ") for direct quotes, attributed correctly (e.g., using the name/position provided).
- Style: Adhere to a professional journalistic style. Be objective and factual.
- Accuracy: Do NOT invent information. Stick strictly to the provided facts, instructions, and source materials. If information is contradictory or missing, state that or omit the detail.
- Structure: Organize the article logically with clear paragraphs.

Begin the article now. [/INST]
Article Draft:
"""

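        # Sketch (assumption, not enabled): rather than hand-writing [INST] tags,
        # recent transformers tokenizers can render the model's own chat template,
        # which keeps the prompt format in sync with the model card:
        #   prompt = model_manager.tokenizer.apply_chat_template(
        #       [{"role": "user", "content": user_message}],  # user_message: hypothetical variable holding the text above
        #       tokenize=False, add_generation_prompt=True)
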
        # Log the prompt length (useful for debugging context limits)
        logger.info(f"Generated prompt length: {len(prompt.split())} words / {len(prompt)} characters.")
        # Avoid logging the full prompt: it can be very long and may contain sensitive input
        # logger.debug(f"Generated Prompt:\n{prompt}")

        # --- Generate News Article ---
        logger.info("Generating news article with LLM...")
        generation_start_time = time.time()

        # Estimate max_new_tokens from the requested size, with a buffer for the
        # title, hook, and potential verbosity
        estimated_tokens_per_word = 1.5
        max_new_tokens = int(size * estimated_tokens_per_word + 150)  # size words + buffer
        # Keep prompt + generation within the model's context window
        model_max_length = 2048  # Typical for TinyLlama, but check the specific model card
        # The prompt token count below is a rough character-based estimate (~3 characters
        # per token); see the tokenizer-based sketch after this block for an exact count
        prompt_tokens_estimate = len(prompt) // 3
        max_new_tokens = min(max_new_tokens, model_max_length - prompt_tokens_estimate - 50)  # Leave a safety buffer
        max_new_tokens = max(max_new_tokens, 100)  # Guarantee a minimum generation length

        logger.info(f"Requesting max_new_tokens: {max_new_tokens}")

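        # Sketch (assumption): an exact token budget could replace the estimate above, e.g.
        #   prompt_tokens = len(model_manager.tokenizer.encode(prompt))
        #   max_new_tokens = max(100, min(max_new_tokens, model_max_length - prompt_tokens - 50))
        # at the cost of one extra tokenization pass per request.
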
        try:
            # Generate with the text-generation pipeline
            outputs = model_manager.text_pipeline(
                prompt,
                max_new_tokens=max_new_tokens,  # Bound new tokens rather than total length
                do_sample=True,
                temperature=0.7,        # Moderate temperature: creative but factual
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.15,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                num_return_sequences=1
            )
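            # The sampling settings above trade determinism for fluency: temperature
            # rescales the logits, top_p/top_k restrict the candidate token pool, and
            # repetition_penalty discourages the model from looping on phrases.
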
            # Extract the generated text
            generated_text = outputs[0]['generated_text']

            # Strip the prompt from the output: find the [/INST] marker and keep what follows
            inst_marker = "[/INST]"
            marker_pos = generated_text.find(inst_marker)
            if marker_pos != -1:
                news_article = generated_text[marker_pos + len(inst_marker):].strip()
                # Also strip a leading "Article Draft:" if the model echoed it
                if news_article.startswith("Article Draft:"):
                    news_article = news_article[len("Article Draft:"):].strip()
            else:
                # Fallback: try removing the input prompt string itself (less reliable)
                if prompt in generated_text:
                    news_article = generated_text.replace(prompt, "", 1).strip()
                else:
                    # If the prompt is not found verbatim, assume the output is only the
                    # generation (some pipelines strip the prompt internally)
                    news_article = generated_text
                    logger.warning("Prompt marker '[/INST]' not found in LLM output. Returning full output.")

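            # Sketch (assumption, not enabled): transformers text-generation pipelines
            # also accept return_full_text=False, which returns only the completion and
            # would make the marker search above unnecessary:
            #   outputs = model_manager.text_pipeline(prompt, ..., return_full_text=False)
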
            generation_time = time.time() - generation_start_time
            logger.info(f"News generation completed in {generation_time:.2f} seconds. Output length: {len(news_article)} characters.")

        except torch.cuda.OutOfMemoryError as oom_error:
            logger.error(f"CUDA out-of-memory error during LLM generation: {oom_error}")
            logger.error(traceback.format_exc())
            model_manager.reset_models(force=True)  # Attempt to recover GPU memory
            raise RuntimeError("Generation failed due to insufficient GPU memory. Please try reducing article size or complexity.") from oom_error
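            # Sketch (assumption about reset_models internals): recovery from OOM usually
            # also requires flushing the CUDA allocator cache, e.g.
            #   gc.collect(); torch.cuda.empty_cache()
            # if reset_models() does not already do so.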
        except Exception as gen_error:
            logger.error(f"Error during text generation pipeline: {str(gen_error)}")
            logger.error(traceback.format_exc())
            raise RuntimeError(f"LLM generation failed: {gen_error}") from gen_error

        total_time = time.time() - request_start_time
        logger.info(f"Total request processing time: {total_time:.2f} seconds.")

        # Return the generated article and the log of raw transcriptions
        return news_article, raw_transcriptions.strip()

    except Exception as e:
        total_time = time.time() - request_start_time
        logger.error(f"Error in generate_news function after {total_time:.2f} seconds: {str(e)}")
        logger.error(traceback.format_exc())
        # Attempt to reset models to recover a usable state
        try:
            model_manager.reset_models(force=True)
        except Exception as reset_error:
            logger.error(f"Failed to reset models after error: {str(reset_error)}")
        # Return error messages to the UI
        error_message = f"Error generating the news article: {str(e)}"
        transcription_log = raw_transcriptions.strip() + f"\n\n[ERROR] News generation failed: {str(e)}"
        return error_message, transcription_log

def create_demo():
    """Creates the Gradio interface"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📰 NewsIA - AI News Generator")
        gr.Markdown("Create professional news articles from multiple information sources.")

        # Collect every input component so they can be passed to generate_news
        # and reset together by the clear button
        all_inputs = []

        with gr.Row():
            with gr.Column(scale=2):
                instructions = gr.Textbox(
                    label="Instructions for the News Article",
                    placeholder="Enter specific instructions for generating your news article (e.g., focus on the economic impact)",
                    lines=2,
                    value=""
                )
                all_inputs.append(instructions)

                facts = gr.Textbox(
                    label="Main Facts",
                    placeholder="Describe the most important facts the news should include (e.g., event name, date, location, key people involved)",
                    lines=4,
                    value=""
                )
                all_inputs.append(facts)

                with gr.Row():
                    size_slider = gr.Slider(
                        label="Approximate Length (words)",
                        minimum=100,
                        maximum=700,
                        value=250,
                        step=50
                    )
                    all_inputs.append(size_slider)

                    tone_dropdown = gr.Dropdown(
                        label="Tone of the News Article",
                        choices=["neutral", "serious", "formal", "urgent", "investigative", "human-interest", "lighthearted"],
                        value="neutral"
                    )
                    all_inputs.append(tone_dropdown)

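                # The slider value is expected to bind positionally to the `size`
                # parameter of generate_news, which drives both the prompt
                # instructions and the max_new_tokens budget.
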
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("📄 Documents"):
                        gr.Markdown("Upload relevant documents (PDF, DOCX, XLSX, CSV). Max 5.")
                        doc_inputs = []
                        for i in range(1, 6):
                            doc_file = gr.File(
                                label=f"Document {i}",
                                file_types=[".pdf", ".docx", ".xlsx", ".csv"],  # Explicit extensions (leading dot) for clarity
                                file_count="single"  # One file per component
                            )
                            doc_inputs.append(doc_file)
                        all_inputs.extend(doc_inputs)

with gr.TabItem("π Audio/Video"):
|
1041 |
+
gr.Markdown("Upload audio or video files for transcription (MP3, WAV, MP4, MOV, etc.). Max 5 sources.")
|
1042 |
+
audio_video_inputs = []
|
1043 |
+
for i in range(1, 6):
|
1044 |
with gr.Group():
|
1045 |
gr.Markdown(f"**Source {i}**")
|
1046 |
+
audio_file = gr.File(
|
1047 |
+
label=f"Audio/Video File {i}",
|
1048 |
file_types=["audio", "video"]
|
1049 |
)
|
1050 |
with gr.Row():
|
1051 |
+
speaker_name = gr.Textbox(
|
1052 |
+
label="Speaker Name",
|
1053 |
+
placeholder="Name of the interviewee or speaker",
|
1054 |
+
value=""
|
1055 |
)
|
1056 |
+
speaker_role = gr.Textbox(
|
1057 |
+
label="Role/Position",
|
1058 |
+
placeholder="Speaker's title or role",
|
1059 |
+
value=""
|
1060 |
)
|
1061 |
+
audio_video_inputs.append(audio_file)
|
1062 |
+
audio_video_inputs.append(speaker_name)
|
1063 |
+
audio_video_inputs.append(speaker_role)
|
1064 |
+
all_inputs.extend(audio_video_inputs)
|
1065 |
|
1066 |
with gr.TabItem("π URLs"):
|
1067 |
+
gr.Markdown("Add URLs to relevant web pages or articles. Max 5.")
|
1068 |
+
url_inputs = []
|
1069 |
+
for i in range(1, 6):
|
1070 |
+
url_textbox = gr.Textbox(
|
1071 |
label=f"URL {i}",
|
1072 |
+
placeholder="https://example.com/article",
|
1073 |
+
value=""
|
1074 |
)
|
1075 |
+
url_inputs.append(url_textbox)
|
1076 |
+
all_inputs.extend(url_inputs)
|
1077 |
|
1078 |
with gr.TabItem("π± Social Media"):
|
1079 |
+
gr.Markdown("Add URLs to social media posts (e.g., Twitter, YouTube, TikTok). Max 3.")
|
1080 |
+
social_inputs = []
|
1081 |
+
for i in range(1, 4):
|
1082 |
with gr.Group():
|
1083 |
+
gr.Markdown(f"**Social Media Source {i}**")
|
1084 |
+
social_url_textbox = gr.Textbox(
|
1085 |
+
label=f"Post URL",
|
1086 |
+
placeholder="https://twitter.com/user/status/...",
|
1087 |
+
value=""
|
1088 |
)
|
1089 |
with gr.Row():
|
1090 |
+
social_name_textbox = gr.Textbox(
|
1091 |
+
label=f"Account Name/User",
|
1092 |
+
placeholder="Name or handle (e.g., @username)",
|
1093 |
+
value=""
|
1094 |
)
|
1095 |
+
social_context_textbox = gr.Textbox(
|
1096 |
+
label=f"Context",
|
1097 |
+
placeholder="Brief context (e.g., statement on event X)",
|
1098 |
+
value=""
|
1099 |
)
|
1100 |
+
social_inputs.append(social_url_textbox)
|
1101 |
+
social_inputs.append(social_name_textbox)
|
1102 |
+
social_inputs.append(social_context_textbox)
|
1103 |
+
all_inputs.extend(social_inputs)
|
1104 |
+
|
1105 |
|
1106 |
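        # Note: all_inputs now holds, in order: instructions, facts, size, tone,
        # 5 document files, 5 audio/video (file, name, role) triples, 5 URLs, and
        # 3 social-media (url, name, context) triples. Gradio passes them to
        # generate_news positionally in exactly this order.
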
        with gr.Row():
            generate_button = gr.Button("✨ Generate News Article", variant="primary")
            clear_button = gr.Button("🔄 Clear All Inputs")

        with gr.Tabs():
            with gr.TabItem("📝 Generated News Article"):
                news_output = gr.Textbox(
                    label="Draft News Article",
                    lines=20,
                    show_copy_button=True,
                    value=""
                )
            with gr.TabItem("🎙️ Source Transcriptions & Logs"):
                transcriptions_output = gr.Textbox(
                    label="Transcriptions and Processing Log",
                    lines=15,
                    show_copy_button=True,
                    value=""
                )

        # --- Event Handlers ---
        outputs_list = [news_output, transcriptions_output]

        # Generate button: pass the consolidated input list to generate_news
        generate_button.click(
            fn=generate_news,
            inputs=all_inputs,
            outputs=outputs_list
        )

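        # Click events run through the queue configured in __main__ below, so a
        # long-running generation does not block other users' requests.
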
        # Clear button: reset every input and output component to its default
        def clear_all_inputs_and_outputs():
            # Build a list of default values matching the number and type of the
            # input components, in the same order as all_inputs
            reset_values = []
            for input_comp in all_inputs:
                if isinstance(input_comp, gr.Dropdown):
                    reset_values.append("neutral")  # Restore the dropdown's original default
                elif isinstance(input_comp, gr.Textbox):
                    reset_values.append("")
                elif isinstance(input_comp, gr.Slider):
                    reset_values.append(250)  # Restore the slider's original default
                elif isinstance(input_comp, gr.File):
                    reset_values.append(None)
                else:
                    reset_values.append(None)  # Fallback for any other component type

            # Add defaults for the two Textbox outputs
            reset_values.extend(["", ""])

            # Also reset the models in the background to free memory
            model_manager.reset_models(force=True)
            logger.info("UI cleared and models reset.")

            return reset_values

        clear_button.click(
            fn=clear_all_inputs_and_outputs,
            inputs=None,  # The clear function needs no inputs
            outputs=all_inputs + outputs_list  # Must match the order of reset_values
        )

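        # Sketch (assumption, Gradio 4+): gr.ClearButton(components=all_inputs + outputs_list)
        # could replace the custom callback above, though it would not perform the
        # model reset that clear_all_inputs_and_outputs does.
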
        # Optionally reset models when the Gradio app closes or reloads
        # (unload hooks are unreliable in Spaces, so this is left disabled):
        # demo.unload(model_manager.reset_models, inputs=None, outputs=None)

    return demo

if __name__ == "__main__":
|
1174 |
+
logger.info("Starting NewsIA application...")
|
1175 |
+
|
1176 |
+
# Optional: Pre-initialize Whisper on startup if desired and resources allow
|
1177 |
+
# This can make the first transcription faster but uses GPU resources immediately.
|
1178 |
+
# Consider enabling only if transcriptions are very common.
|
1179 |
+
# try:
|
1180 |
+
# logger.info("Attempting to pre-initialize Whisper model...")
|
1181 |
+
# model_manager.initialize_whisper()
|
1182 |
+
# except Exception as e:
|
1183 |
+
# logger.warning(f"Pre-initialization of Whisper model failed (will load on demand): {str(e)}")
|
1184 |
+
|
1185 |
+
# Create the Gradio Demo
|
1186 |
+
news_demo = create_demo()
|
1187 |
+
|
1188 |
+
# Configure the queue - remove concurrency_count and max_size
|
1189 |
+
# Use default queue settings, suitable for most Spaces environments
|
1190 |
+
news_demo.queue()
|
1191 |
+
|
1192 |
+
# Launch the Gradio app
|
1193 |
+
logger.info("Launching Gradio interface...")
|
1194 |
+
news_demo.launch(
|
1195 |
+
server_name="0.0.0.0", # Necessary for Docker/Spaces
|
1196 |
+
server_port=7860,
|
1197 |
+
# share=True # Share=True is often handled by Spaces automatically, can be removed
|
1198 |
+
# debug=True # Enable for more detailed Gradio logs if needed
|
1199 |
+
)
|
1200 |
+
logger.info("NewsIA application finished.")
|