Update app.py
app.py
CHANGED
@@ -17,9 +17,7 @@ from functools import lru_cache
 import gc
 import time
 from huggingface_hub import login
-from transformers import AutoTokenizer,
-from unsloth import FastLanguageModel
-import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 # Configure logging
 logging.basicConfig(
@@ -46,15 +44,17 @@ class ModelManager:
         if not self._initialized:
             self.tokenizer = None
             self.model = None
+            self.pipeline = None
             self.whisper_model = None
             self._initialized = True
             self.last_used = time.time()
 
     @spaces.GPU()
     def initialize_llm(self):
-        """Initialize LLM model with
+        """Initialize LLM model with standard transformers"""
         try:
+            # Use small model for ZeroGPU compatibility
+            MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -62,43 +62,30 @@ class ModelManager:
                 token=HUGGINGFACE_TOKEN,
                 use_fast=True,
             )
-            self.tokenizer.pad_token = self.tokenizer.eos_token
 
-            # ... (Unsloth loading attempt, lines not shown in this diff) ...
-                # Fallback to standard transformers
-                logger.warning(f"Unsloth optimization failed: {str(unsloth_error)}. Falling back to standard model.")
-                from transformers import AutoModelForCausalLM
-
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    MODEL_NAME,
-                    token=HUGGINGFACE_TOKEN,
-                    device_map="auto",
-                    torch_dtype=torch.float16,
-                    load_in_8bit=True
-                )
-                logger.info("Model loaded with standard transformers")
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Basic memory settings for ZeroGPU
+            logger.info("Loading model...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                token=HUGGINGFACE_TOKEN,
+                device_map="auto",
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+
+            # Create text generation pipeline
+            logger.info("Creating pipeline...")
+            self.pipeline = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                max_length=2048
+            )
 
             logger.info("LLM initialized successfully")
             self.last_used = time.time()
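This commit swaps the Unsloth loading path for plain transformers: a small instruct model (TinyLlama-1.1B-Chat) loaded in float16 with device_map="auto" and wrapped in a text-generation pipeline, which keeps memory low enough for ZeroGPU. For reference, the same loading path can be exercised outside the Space with a short standalone script. This is only a sketch under the settings shown in the hunk above, not code from the commit; TinyLlama is an open checkpoint, so no access token is passed here.

# Standalone sketch of the loading path introduced above (assumes a recent transformers release).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Quick smoke test: generate a short continuation.
print(generator("Write one sentence about local news.", max_new_tokens=50, do_sample=True)[0]["generated_text"])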
@@ -128,7 +115,7 @@ class ModelManager:
 
     def check_llm_initialized(self):
         """Check if LLM is initialized and initialize if needed"""
-        if self.tokenizer is None or self.model is None:
+        if self.tokenizer is None or self.model is None or self.pipeline is None:
             logger.info("LLM not initialized, initializing...")
             self.initialize_llm()
         self.last_used = time.time()
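With the pipeline added to this check, the LLM now loads lazily on the first request that needs it rather than at startup. A sketch of how a caller can rely on that guard; handle_request is a hypothetical name used only for illustration, and model_manager is the singleton defined in app.py.

def handle_request(prompt: str) -> str:
    # First call triggers initialize_llm(); later calls reuse the cached pipeline.
    model_manager.check_llm_initialized()
    outputs = model_manager.pipeline(prompt, max_length=512, do_sample=True)
    return outputs[0]["generated_text"]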
@@ -154,11 +141,15 @@ class ModelManager:
         if hasattr(self, 'tokenizer') and self.tokenizer is not None:
             del self.tokenizer
 
+        if hasattr(self, 'pipeline') and self.pipeline is not None:
+            del self.pipeline
+
         if hasattr(self, 'whisper_model') and self.whisper_model is not None:
             del self.whisper_model
 
         self.tokenizer = None
         self.model = None
+        self.pipeline = None
         self.whisper_model = None
 
         if torch.cuda.is_available():
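The hunk stops at the torch.cuda.is_available() check, so the rest of the cleanup body is not visible in this diff. As context, the usual release pattern for a manager like this looks like the sketch below; the gc.collect() and torch.cuda.empty_cache() calls are assumptions suggested by the import gc at the top of the file, not lines taken from app.py.

import gc
import torch

def release_models(manager) -> None:
    # Drop strong references so the large objects become collectable.
    for attr in ("pipeline", "model", "tokenizer", "whisper_model"):
        if getattr(manager, attr, None) is not None:
            setattr(manager, attr, None)

    # Reclaim Python-side memory, then return cached CUDA blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()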
@@ -490,62 +481,45 @@ Follow these requirements:
 - Do not invent information
 - Be rigorous with the provided facts [/INST]"""
 
-    # ... (generation call, lines not shown in this diff) ...
-        )
-
-        # Unsloth specific decoding
-        generated_text = model_manager.tokenizer.decode(
-            outputs[0][inputs.input_ids.shape[1]:],
-            skip_special_tokens=True
-        )
+    # Generate with standard pipeline
+    try:
+        logger.info("Generating news article...")
+
+        # Set max length based on requested size
+        max_length = min(len(prompt.split()) + size * 2, 2048)
+
+        # Generate using the pipeline
+        outputs = model_manager.pipeline(
+            prompt,
+            max_length=max_length,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.95,
+            repetition_penalty=1.2,
+            pad_token_id=model_manager.tokenizer.eos_token_id,
+            num_return_sequences=1
+        )
+
+        # Extract generated text
+        generated_text = outputs[0]['generated_text']
+
+        # Clean up the result by removing the prompt
+        if "[/INST]" in generated_text:
+            news_article = generated_text.split("[/INST]")[1].strip()
+        else:
+            # Try to extract the text after the prompt
+            prompt_words = prompt.split()[:50]  # Use first 50 words to identify
+            prompt_fragment = " ".join(prompt_words)
+            if prompt_fragment in generated_text:
+                news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
             else:
-            # ... (decode call, lines not shown in this diff) ...
-                inputs.input_ids[0],
-                skip_special_tokens=True
-            )
-            generated_text = generated_text.replace(prompt_text, "")
-
-        # Clean up the generated text
-        news_article = generated_text.strip()
-        logger.info(f"News generation completed: {len(news_article)} chars")
-
-    except Exception as gen_error:
-        logger.error(f"Error in text generation: {str(gen_error)}")
-        raise
+                news_article = generated_text
+
+        logger.info(f"News generation completed: {len(news_article)} chars")
+
+    except Exception as gen_error:
+        logger.error(f"Error in text generation: {str(gen_error)}")
+        raise
 
     return news_article, raw_transcriptions
 
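The added code trims the prompt out of the pipeline output by splitting on "[/INST]" and then falling back to substring matching. A possible alternative, shown only as a sketch and not what the commit does: the transformers text-generation pipeline accepts return_full_text=False, which makes it return just the continuation, and max_new_tokens bounds the completion directly instead of deriving a total max_length. The prompt, size and model_manager names below are the ones from app.py.

# Alternative sketch: let the pipeline drop the prompt itself.
outputs = model_manager.pipeline(
    prompt,
    max_new_tokens=size * 2,     # bound new tokens instead of total sequence length
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.2,
    return_full_text=False,      # return only the generated continuation
)
news_article = outputs[0]["generated_text"].strip()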
@@ -710,12 +684,11 @@ def create_demo():
     return demo
 
 if __name__ == "__main__":
-    # Initialize models on startup to reduce first request latency
     try:
+        # Try initializing whisper model on startup
        model_manager.initialize_whisper()
-        model_manager.initialize_llm()
     except Exception as e:
-        logger.warning(f"Initial model loading failed: {str(e)}")
+        logger.warning(f"Initial whisper model loading failed: {str(e)}")
 
     demo = create_demo()
     demo.queue(concurrency_count=1, max_size=5)
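Startup now warms up only the Whisper model; the LLM is deferred to the first generation request. Request handling is serialized by the Gradio queue: concurrency_count=1 allows one job at a time on the shared GPU and max_size=5 caps how many requests may wait. Below is a minimal sketch of the same queue configuration, assuming Gradio 3.x, where queue() still accepts concurrency_count.

import gradio as gr

def echo(text: str) -> str:
    return text

demo = gr.Interface(fn=echo, inputs="text", outputs="text")
# One request at a time on the shared GPU; cap the waiting queue at 5 requests.
demo.queue(concurrency_count=1, max_size=5)
demo.launch()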