CamiloVega committed
Commit fb3575f · verified · 1 Parent(s): 9dd3257

Update app.py

Files changed (1): app.py (+59, −37)
app.py CHANGED
@@ -52,7 +52,7 @@ class ModelManager:
 
     @spaces.GPU()
     def initialize_llm(self):
-        """Initialize LLM model with unsloth optimization"""
+        """Initialize LLM model with optimization"""
         try:
             MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
 
@@ -64,36 +64,41 @@
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-            # Configure 4-bit quantization
-            bnb_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_use_double_quant=True
-            )
-
-            logger.info("Loading and optimizing model with unsloth...")
-            # Use unsloth to load and optimize the model
-            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
-                model_name=MODEL_NAME,
-                token=HUGGINGFACE_TOKEN,
-                quantization_config=bnb_config,
-                max_seq_length=2048,
-                device_map="auto"
-            )
-
-            # Optimize with unsloth
-            self.model = FastLanguageModel.get_peft_model(
-                self.model,
-                r=16,
-                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
-                                "gate_proj", "up_proj", "down_proj"],
-                lora_alpha=16,
-                lora_dropout=0,
-                bias="none",
-                use_gradient_checkpointing=True,
-                random_state=3407
-            )
+            try:
+                # Try with unsloth first
+                logger.info("Attempting to load model with unsloth optimization...")
+                self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+                    model_name=MODEL_NAME,
+                    token=HUGGINGFACE_TOKEN,
+                    load_in_8bit=True,
+                    max_seq_length=2048,
+                    device_map="auto"
+                )
+
+                # Optimize with unsloth
+                self.model = FastLanguageModel.get_peft_model(
+                    self.model,
+                    r=8,
+                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+                    lora_alpha=8,
+                    bias="none"
+                )
+
+                logger.info("Model loaded successfully with unsloth")
+
+            except Exception as unsloth_error:
+                # Fallback to standard transformers
+                logger.warning(f"Unsloth optimization failed: {str(unsloth_error)}. Falling back to standard model.")
+                from transformers import AutoModelForCausalLM
+
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    MODEL_NAME,
+                    token=HUGGINGFACE_TOKEN,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                    load_in_8bit=True
+                )
+                logger.info("Model loaded with standard transformers")
 
             logger.info("LLM initialized successfully")
             self.last_used = time.time()
@@ -492,14 +497,17 @@ Follow these requirements:
    with torch.inference_mode():
        try:
            logger.info("Generating news article...")
-            # Use unsloth's optimized generate method
+            # Check if we're using unsloth or standard model
+            is_unsloth = hasattr(model_manager.model, 'unsloth_module') if hasattr(model_manager.model, 'unsloth_module') else False
+
+            # Prepare inputs
            inputs = model_manager.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False
            ).to(model_manager.model.device)
 
-            # Generate with optimized settings
+            # Generate with appropriate settings
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
@@ -512,10 +520,24 @@
            )
 
            # Decode the generated text
-            generated_text = model_manager.tokenizer.decode(
-                outputs[0][inputs.input_ids.shape[1]:],
-                skip_special_tokens=True
-            )
+            if is_unsloth:
+                # Unsloth specific decoding
+                generated_text = model_manager.tokenizer.decode(
+                    outputs[0][inputs.input_ids.shape[1]:],
+                    skip_special_tokens=True
+                )
+            else:
+                # Standard transformers decoding
+                generated_text = model_manager.tokenizer.decode(
+                    outputs[0],
+                    skip_special_tokens=True
+                )
+                # Remove the prompt from the generated text
+                prompt_text = model_manager.tokenizer.decode(
+                    inputs.input_ids[0],
+                    skip_special_tokens=True
+                )
+                generated_text = generated_text.replace(prompt_text, "")
 
            # Clean up the generated text
            news_article = generated_text.strip()
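
For reference, a minimal standalone sketch of the load-with-fallback pattern that initialize_llm uses above: try unsloth's FastLanguageModel first, and fall back to a plain transformers AutoModelForCausalLM if that raises. The function name load_model_with_fallback is hypothetical, the unsloth keyword arguments (load_in_8bit, max_seq_length, device_map) are copied from the diff rather than independently verified, and the fallback omits 8-bit loading so the sketch does not require bitsandbytes.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_with_fallback(model_name: str, hf_token: str):
    # Tokenizer is loaded up front either way, as in app.py.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    tokenizer.pad_token = tokenizer.eos_token
    try:
        # Preferred path: unsloth's optimized loader (arguments mirror the commit).
        from unsloth import FastLanguageModel
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            token=hf_token,
            load_in_8bit=True,
            max_seq_length=2048,
            device_map="auto",
        )
    except Exception as unsloth_error:
        # Fallback path: plain transformers with fp16 weights placed by accelerate.
        print(f"unsloth unavailable ({unsloth_error}); using transformers")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=hf_token,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    return model, tokenizer

# Usage (token handling as in the Space):
# model, tokenizer = load_model_with_fallback("meta-llama/Llama-2-7b-chat-hf", HUGGINGFACE_TOKEN)

The two loading paths are also why the generation code branches on is_unsloth: the unsloth path slices the output tensor at the prompt boundary before decoding, while the transformers path decodes the full sequence and strips the prompt text afterwards.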