Ubuntu committed on
Commit a4bab9c · 1 Parent(s): ad4670f

inference script added

Files changed (1)
  1. compare_models.py +158 -0
compare_models.py ADDED
@@ -0,0 +1,158 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import argparse
import time
import sys

def stream_response(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a streaming response from the model for the given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Store the input length to identify the response part
    input_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))

    # Initialize generation
    generated_ids = inputs.input_ids.clone()
    past_key_values = None
    response_text = ""

    # Add stop sequences to detect natural endings
    stop_sequences = ["\n\n", "\nExercise:", "\nQuestion:"]
    was_truncated = False  # Track if response was truncated

    # Generate tokens one by one
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Forward pass
            outputs = model(
                input_ids=generated_ids[:, -1:] if past_key_values is not None else generated_ids,
                past_key_values=past_key_values,
                use_cache=True
            )

        # Get logits and past key values
        logits = outputs.logits
        past_key_values = outputs.past_key_values

        # Sample next token
        next_token_logits = logits[:, -1, :]
        next_token_logits = next_token_logits / 0.7  # Apply temperature

        # Apply top-p sampling
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > 0.9
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = -float('Inf')

        # Sample from the filtered distribution
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to generated ids
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # Decode the current token
        current_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract only the new part (response)
        if len(current_text) > input_length:
            new_text = current_text[input_length:]
            # Print only the new characters
            new_chars = new_text[len(response_text):]
            sys.stdout.write(new_chars)
            sys.stdout.flush()
            response_text = new_text

            # Add a small delay to simulate typing
            time.sleep(0.01)

        # Stop if we generate an EOS token
        if next_token[0, 0].item() == tokenizer.eos_token_id:
            break

        # Check for natural stopping points
        if any(stop_seq in response_text for stop_seq in stop_sequences):
            # If we find a stop sequence, only keep text up to that point
            for stop_seq in stop_sequences:
                if stop_seq in response_text:
                    stop_idx = response_text.find(stop_seq)
                    if stop_idx > 0:  # Only trim if we have some content
                        was_truncated = True
                        response_text = response_text[:stop_idx]
                        sys.stdout.write("\n")  # Add a newline for cleaner output
                        sys.stdout.flush()
                        return response_text, was_truncated

    # Return the full response
    return response_text, was_truncated

def main():
    parser = argparse.ArgumentParser(description="Compare base and fine-tuned Phi-2 models")
    parser.add_argument("--base-only", action="store_true", help="Use only the base model")
    parser.add_argument("--finetuned-only", action="store_true", help="Use only the fine-tuned model")
    parser.add_argument("--adapter-path", type=str, default="./phi2-grpo-qlora-final",
                        help="Path to the fine-tuned adapter")
    args = parser.parse_args()

    # Load the base model and tokenizer
    base_model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Configure tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    # Load models based on arguments
    models = {}

    if not args.finetuned_only:
        print("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        models["Base Phi-2"] = base_model

    if not args.base_only:
        print(f"Loading fine-tuned model from {args.adapter_path}...")
        # Load the base model first (with same quantization as during training)
        base_model_for_ft = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # Load the adapter on top of the base model
        finetuned_model = PeftModel.from_pretrained(base_model_for_ft, args.adapter_path)
        models["Fine-tuned Phi-2"] = finetuned_model

    # Interactive prompt loop
    print("\n" + "="*50)
    print("Interactive Phi-2 Model Comparison (Streaming Mode)")
    print("Type 'exit' to quit")
    print("="*50 + "\n")

    while True:
        # Get user input
        user_prompt = input("\nEnter your prompt: ")
        if user_prompt.lower() == 'exit':
            break

        print("\n" + "-"*50)

        # Generate responses from each model
        for model_name, model in models.items():
            print(f"\n{model_name} response:")
            response, was_truncated = stream_response(model, tokenizer, user_prompt)
            if was_truncated:
                print("\n[Note: Response was truncated at a natural stopping point]")
            print("\n" + "-"*30)

        print("-"*50)

if __name__ == "__main__":
    main()
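
Usage note (not part of the commit): running python compare_models.py starts the interactive comparison loop with the flags defined above (--base-only, --finetuned-only, --adapter-path). The minimal sketch below shows how stream_response could also be reused programmatically against only the fine-tuned model; it assumes compare_models.py is importable from the working directory and that the adapter was saved at the script's default path ./phi2-grpo-qlora-final.

# Hypothetical driver script; it mirrors the loading logic in compare_models.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from compare_models import stream_response  # assumes this file is on the Python path

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
tokenizer.pad_token = tokenizer.eos_token

# Load the base model, then attach the LoRA adapter on top of it.
base = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "./phi2-grpo-qlora-final")

# Stream one completion and report whether it was cut at a stop sequence.
text, truncated = stream_response(model, tokenizer, "Explain KV caching in one sentence.")
print("\n[truncated]" if truncated else "\n[complete]")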