# Phi-2-fine-tuned-with-GRPO / model_handler.py
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_and_tokenizer():
    """Load the fine-tuned model and tokenizer, ensuring CPU compatibility."""
    # Load tokenizer
    base_model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Create offload directory if it doesn't exist
    offload_dir = "offload_dir"
    os.makedirs(offload_dir, exist_ok=True)

    # Check if CUDA is available
    use_cuda = torch.cuda.is_available()

    try:
        # First try loading with 8-bit quantization if CUDA is available
        # (load_in_8bit requires the bitsandbytes package)
        if use_cuda:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                offload_folder=offload_dir,
                load_in_8bit=True,
                low_cpu_mem_usage=True
            )
        else:
            # CPU-only loading without quantization
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float32,
                device_map="auto",
                offload_folder=offload_dir,
                low_cpu_mem_usage=True
            )

        # Load adapter weights
        model = PeftModel.from_pretrained(
            base_model,
            "phi2-grpo-qlora-final",
            device_map="auto",
            offload_folder=offload_dir
        )
    except Exception as e:
        print(f"Error loading with adapter: {e}")
        print("Falling back to base model only...")
        # Fall back to just the base model if adapter loading fails
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True
        )

    # Set to evaluation mode
    model.eval()
    return model, tokenizer
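
# Note: on recent transformers releases, passing load_in_8bit=True directly to
# from_pretrained() is deprecated in favor of an explicit quantization config.
# A minimal sketch of the equivalent CUDA-path call, assuming bitsandbytes is
# installed (not part of the original file):
#
#     from transformers import BitsAndBytesConfig
#     base_model = AutoModelForCausalLM.from_pretrained(
#         base_model_name,
#         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#         torch_dtype=torch.float16,
#         device_map="auto",
#         offload_folder=offload_dir,
#         low_cpu_mem_usage=True,
#     )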

def generate_response(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a streaming response from the model for the given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Store the decoded input length to identify the response part later
    input_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))

    # Initialize generation state
    generated_ids = inputs.input_ids.clone()
    past_key_values = None
    response_text = ""

    # Stop sequences used to detect natural endings
    stop_sequences = ["\n\n", "\nExercise:", "\nQuestion:"]

    # Generate tokens one by one
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Forward pass: only the last token is needed once the KV cache is primed
            outputs = model(
                input_ids=generated_ids[:, -1:] if past_key_values is not None else generated_ids,
                past_key_values=past_key_values,
                use_cache=True
            )

        # Get logits and past key values
        logits = outputs.logits
        past_key_values = outputs.past_key_values

        # Sample next token
        next_token_logits = logits[:, -1, :]
        next_token_logits = next_token_logits / 0.7  # Apply temperature

        # Apply top-p (nucleus) sampling with p = 0.9
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > 0.9
        # Shift the mask right so the first token that crosses the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = -float('Inf')

        # Sample from the filtered distribution
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to generated ids
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # Decode the full sequence and extract only the new part (the response)
        current_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        if len(current_text) > input_length:
            new_text = current_text[input_length:]

            # Check for natural stopping points before yielding, so nothing
            # past a stop sequence is streamed to the caller
            stop_idx = -1
            for stop_seq in stop_sequences:
                idx = new_text.find(stop_seq)
                if idx != -1 and (stop_idx == -1 or idx < stop_idx):
                    stop_idx = idx
            if stop_idx > 0:  # Only trim if we have some content
                remaining = new_text[len(response_text):stop_idx]
                if remaining:
                    yield remaining
                return

            # Yield only the characters added since the last step
            new_chars = new_text[len(response_text):]
            response_text = new_text
            if new_chars:
                yield new_chars

        # Stop if we generate an EOS token
        if next_token[0, 0].item() == tokenizer.eos_token_id:
            break
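
# A minimal smoke test for running this module directly. This is an
# illustrative sketch, not part of the original file: the example prompt format
# is an assumption and may not match what the model was fine-tuned on, and
# loading falls back to the base model if the "phi2-grpo-qlora-final" adapter
# directory is not available locally.
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()

    prompt = "Question: What is gradient descent?\nAnswer:"
    print(prompt, end="", flush=True)

    # generate_response() is a generator that yields new characters as they are
    # produced, so the chunks can be streamed straight to stdout.
    for chunk in generate_response(model, tokenizer, prompt, max_new_tokens=128):
        print(chunk, end="", flush=True)
    print()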