CamiloVega committed on
Commit 7a1615b · verified · 1 parent: ef7abd1

Update app.py

Files changed (1): app.py (+73, -100)
app.py CHANGED
@@ -17,9 +17,7 @@ from functools import lru_cache
 import gc
 import time
 from huggingface_hub import login
-from transformers import AutoTokenizer, BitsAndBytesConfig
-from unsloth import FastLanguageModel
-import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 # Configure logging
 logging.basicConfig(
@@ -46,15 +44,17 @@ class ModelManager:
         if not self._initialized:
             self.tokenizer = None
             self.model = None
+            self.pipeline = None
             self.whisper_model = None
             self._initialized = True
             self.last_used = time.time()
 
     @spaces.GPU()
     def initialize_llm(self):
-        """Initialize LLM model with optimization"""
+        """Initialize LLM model with standard transformers"""
         try:
-            MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
+            # Use small model for ZeroGPU compatibility
+            MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -62,43 +62,30 @@ class ModelManager:
                 token=HUGGINGFACE_TOKEN,
                 use_fast=True,
             )
-            self.tokenizer.pad_token = self.tokenizer.eos_token
 
-            try:
-                # Try with unsloth first
-                logger.info("Attempting to load model with unsloth optimization...")
-                self.model, self.tokenizer = FastLanguageModel.from_pretrained(
-                    model_name=MODEL_NAME,
-                    token=HUGGINGFACE_TOKEN,
-                    load_in_8bit=True,
-                    max_seq_length=2048,
-                    device_map="auto"
-                )
-
-                # Optimize with unsloth
-                self.model = FastLanguageModel.get_peft_model(
-                    self.model,
-                    r=8,
-                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-                    lora_alpha=8,
-                    bias="none"
-                )
-
-                logger.info("Model loaded successfully with unsloth")
-
-            except Exception as unsloth_error:
-                # Fallback to standard transformers
-                logger.warning(f"Unsloth optimization failed: {str(unsloth_error)}. Falling back to standard model.")
-                from transformers import AutoModelForCausalLM
-
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    MODEL_NAME,
-                    token=HUGGINGFACE_TOKEN,
-                    device_map="auto",
-                    torch_dtype=torch.float16,
-                    load_in_8bit=True
-                )
-                logger.info("Model loaded with standard transformers")
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Basic memory settings for ZeroGPU
+            logger.info("Loading model...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                token=HUGGINGFACE_TOKEN,
+                device_map="auto",
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+
+            # Create text generation pipeline
+            logger.info("Creating pipeline...")
+            self.pipeline = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                max_length=2048
+            )
 
             logger.info("LLM initialized successfully")
             self.last_used = time.time()
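The new initialize_llm() path can be exercised on its own. Below is a minimal standalone sketch of the same pattern; the model name and loading arguments come from the hunk above, while the HF_TOKEN environment variable and the test prompt are illustrative assumptions rather than part of the commit.

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
hf_token = os.environ.get("HF_TOKEN")  # assumption: token supplied via environment

# Tokenizer first, ensuring a pad token exists (as in the hunk above)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# fp16 weights with low CPU memory usage, mirroring the ZeroGPU-friendly settings
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=hf_token,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Wrap model and tokenizer in a text-generation pipeline
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

print(generator("Write one sentence about local news.", max_new_tokens=40)[0]["generated_text"])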
@@ -128,7 +115,7 @@ class ModelManager:
 
     def check_llm_initialized(self):
         """Check if LLM is initialized and initialize if needed"""
-        if self.tokenizer is None or self.model is None:
+        if self.tokenizer is None or self.model is None or self.pipeline is None:
             logger.info("LLM not initialized, initializing...")
             self.initialize_llm()
             self.last_used = time.time()
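With lazy initialization, callers only need to run check_llm_initialized() before touching the pipeline. A minimal sketch of that call path follows; the handler function and its generation settings are illustrative and not taken from the commit.

def handle_generation_request(prompt: str) -> str:
    """Illustrative request path: ensure the LLM is loaded, then generate."""
    model_manager.check_llm_initialized()  # loads tokenizer, model and pipeline on first use
    outputs = model_manager.pipeline(
        prompt,
        max_new_tokens=256,  # assumption: per-request cap, not part of the diff
        do_sample=True,
    )
    return outputs[0]["generated_text"]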
@@ -154,11 +141,15 @@ class ModelManager:
             if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                 del self.tokenizer
 
+            if hasattr(self, 'pipeline') and self.pipeline is not None:
+                del self.pipeline
+
             if hasattr(self, 'whisper_model') and self.whisper_model is not None:
                 del self.whisper_model
 
             self.tokenizer = None
             self.model = None
+            self.pipeline = None
             self.whisper_model = None
 
             if torch.cuda.is_available():
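The hunk is cut off at the CUDA check, so the rest of the cleanup is not shown. The usual tail of this kind of teardown looks roughly like the sketch below; the exact statements in app.py are assumptions.

import gc
import torch

def free_accelerator_memory():
    """Sketch of a typical cleanup tail: drop Python references, then empty the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()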
@@ -490,62 +481,45 @@ Follow these requirements:
 - Do not invent information
 - Be rigorous with the provided facts [/INST]"""
 
-        # Optimize for requested size
-        max_new_tokens = min(int(size * 2.5), 1024)  # Increased limit for better quality
-
-        # Generate response using optimized unsloth model
-        with torch.inference_mode():
-            try:
-                logger.info("Generating news article...")
-                # Check if we're using unsloth or standard model
-                is_unsloth = hasattr(model_manager.model, 'unsloth_module') if hasattr(model_manager.model, 'unsloth_module') else False
-
-                # Prepare inputs
-                inputs = model_manager.tokenizer(
-                    prompt,
-                    return_tensors="pt",
-                    add_special_tokens=False
-                ).to(model_manager.model.device)
-
-                # Generate with appropriate settings
-                outputs = model_manager.model.generate(
-                    **inputs,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=True,
-                    temperature=0.7,
-                    top_p=0.95,
-                    repetition_penalty=1.2,
-                    pad_token_id=model_manager.tokenizer.eos_token_id,
-                    use_cache=True
-                )
-
-                # Decode the generated text
-                if is_unsloth:
-                    # Unsloth specific decoding
-                    generated_text = model_manager.tokenizer.decode(
-                        outputs[0][inputs.input_ids.shape[1]:],
-                        skip_special_tokens=True
-                    )
+        # Generate with standard pipeline
+        try:
+            logger.info("Generating news article...")
+
+            # Set max length based on requested size
+            max_length = min(len(prompt.split()) + size * 2, 2048)
+
+            # Generate using the pipeline
+            outputs = model_manager.pipeline(
+                prompt,
+                max_length=max_length,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                repetition_penalty=1.2,
+                pad_token_id=model_manager.tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+            # Extract generated text
+            generated_text = outputs[0]['generated_text']
+
+            # Clean up the result by removing the prompt
+            if "[/INST]" in generated_text:
+                news_article = generated_text.split("[/INST]")[1].strip()
+            else:
+                # Try to extract the text after the prompt
+                prompt_words = prompt.split()[:50]  # Use first 50 words to identify
+                prompt_fragment = " ".join(prompt_words)
+                if prompt_fragment in generated_text:
+                    news_article = generated_text[generated_text.find(prompt_fragment) + len(prompt_fragment):].strip()
                 else:
-                    # Standard transformers decoding
-                    generated_text = model_manager.tokenizer.decode(
-                        outputs[0],
-                        skip_special_tokens=True
-                    )
-                    # Remove the prompt from the generated text
-                    prompt_text = model_manager.tokenizer.decode(
-                        inputs.input_ids[0],
-                        skip_special_tokens=True
-                    )
-                    generated_text = generated_text.replace(prompt_text, "")
-
-                # Clean up the generated text
-                news_article = generated_text.strip()
-                logger.info(f"News generation completed: {len(news_article)} chars")
-
-            except Exception as gen_error:
-                logger.error(f"Error in text generation: {str(gen_error)}")
-                raise
+                    news_article = generated_text
+
+            logger.info(f"News generation completed: {len(news_article)} chars")
+
+        except Exception as gen_error:
+            logger.error(f"Error in text generation: {str(gen_error)}")
+            raise
 
         return news_article, raw_transcriptions
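The prompt-stripping logic above works because the prompt itself ends in "[/INST]". An alternative, not used in this commit, is to let the text-generation pipeline drop the echoed prompt via return_full_text=False. A small sketch reusing the sampling settings from the hunk:

def generate_article(text_generator, prompt: str, max_new_tokens: int = 512) -> str:
    """Alternative sketch: the pipeline returns only newly generated text."""
    outputs = text_generator(
        prompt,
        max_new_tokens=max_new_tokens,  # assumption: cap new tokens instead of total max_length
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.2,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()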
 
@@ -710,12 +684,11 @@ def create_demo():
     return demo
 
 if __name__ == "__main__":
-    # Initialize models on startup to reduce first request latency
     try:
+        # Try initializing whisper model on startup
        model_manager.initialize_whisper()
-        model_manager.initialize_llm()
     except Exception as e:
-        logger.warning(f"Initial model loading failed: {str(e)}")
+        logger.warning(f"Initial whisper model loading failed: {str(e)}")
 
     demo = create_demo()
     demo.queue(concurrency_count=1, max_size=5)
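The diff ends at demo.queue(), so the launch call that presumably follows is not shown. A minimal sketch of how such an entry point typically finishes; the launch line is an assumption, and concurrency_count is the Gradio 3.x queue signature.

demo = create_demo()
demo.queue(concurrency_count=1, max_size=5)  # Gradio 3.x queue signature, as in the diff
demo.launch()  # assumption: the launch call sits outside the shown hunk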
 