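"""Streaming inference for a QLoRA fine-tuned phi-2 model.

Loads the microsoft/phi-2 base model together with the "phi2-grpo-qlora-final"
PEFT adapter (falling back to the base model alone if the adapter cannot be
loaded) and streams generated text token by token using temperature and
top-p (nucleus) sampling.
"""
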
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import os

def load_model_and_tokenizer():
    """Load the fine-tuned model and tokenizer, ensuring CPU compatibility."""
    # Load tokenizer
    base_model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # Create offload directory if it doesn't exist
    offload_dir = "offload_dir"
    os.makedirs(offload_dir, exist_ok=True)
    
    # Check if CUDA is available
    use_cuda = torch.cuda.is_available()
    
    try:
        # First try loading with quantization if CUDA is available
        if use_cuda:
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                offload_folder=offload_dir,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                low_cpu_mem_usage=True
            )
        else:
            # CPU-only loading without quantization
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float32,
                device_map="auto",
                offload_folder=offload_dir,
                low_cpu_mem_usage=True
            )
        
        # Load adapter weights
        model = PeftModel.from_pretrained(
            base_model, 
            "phi2-grpo-qlora-final",
            device_map="auto",
            offload_folder=offload_dir
        )
    except Exception as e:
        print(f"Error loading with adapter: {e}")
        print("Falling back to base model only...")
        # Fallback to just the base model if adapter loading fails
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True
        )
    
    # Set to evaluation mode
    model.eval()
    
    return model, tokenizer

def generate_response(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a streaming response from the model for the given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Store the input length to identify the response part
    input_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
    
    # Initialize generation
    generated_ids = inputs.input_ids.clone()
    past_key_values = None
    response_text = ""
    
    # Add stop sequences to detect natural endings
    stop_sequences = ["\n\n", "\nExercise:", "\nQuestion:"]
    
    # Generate tokens one by one
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Forward pass
            outputs = model(
                input_ids=generated_ids[:, -1:] if past_key_values is not None else generated_ids,
                past_key_values=past_key_values,
                use_cache=True
            )
            
            # Get logits and past key values
            logits = outputs.logits
            past_key_values = outputs.past_key_values
            
            # Sample next token
            next_token_logits = logits[:, -1, :]
            next_token_logits = next_token_logits / 0.7  # Apply temperature
            
            # Apply top-p sampling
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > 0.9
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            next_token_logits[indices_to_remove] = -float('Inf')
            
            # Sample from the filtered distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to generated ids
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            
            # Decode the full sequence and extract only the response part
            current_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            if len(current_text) > input_length:
                new_text = current_text[input_length:]

                # Check for natural stopping points before streaming anything new,
                # so that text after a stop sequence is never emitted
                stop_positions = [new_text.find(s) for s in stop_sequences if s in new_text]
                if stop_positions:
                    stop_idx = min(stop_positions)
                    if stop_idx > 0:  # Only trim if we have some content
                        # Yield whatever precedes the stop sequence and has not been streamed yet
                        if stop_idx > len(response_text):
                            yield new_text[len(response_text):stop_idx]
                        return

                # Get only the characters that are new since the last step
                new_chars = new_text[len(response_text):]
                response_text = new_text

                # Yield the new characters for streaming
                yield new_chars

            # Stop if we generate an EOS token
            if next_token[0, 0].item() == tokenizer.eos_token_id:
                break
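

# Example usage (illustrative sketch, not part of the core functions above).
# The prompt format and token budget below are assumptions; adjust them to
# match how the fine-tuned model was trained.
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()

    prompt = "Question: What does gradient descent do?\nAnswer:"
    print(prompt, end="", flush=True)

    # generate_response is a generator, so iterate over it to stream text
    # to the terminal as each token is produced
    for chunk in generate_response(model, tokenizer, prompt, max_new_tokens=128):
        print(chunk, end="", flush=True)
    print()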