import argparse
import sys
import time

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def stream_response(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a streaming response from the model for the given prompt.

    Returns a tuple of (response_text, was_truncated).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Character length of the decoded prompt, used later to strip the prompt
    # from the decoded sequence so only newly generated text is streamed.
    input_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))

    generated_ids = inputs.input_ids.clone()
    past_key_values = None
    response_text = ""

    # Heuristic stop markers: cut the response off if the model starts a new
    # paragraph, exercise, or question.
    stop_sequences = ["\n\n", "\nExercise:", "\nQuestion:"]
    was_truncated = False

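    # Decode manually, one token per iteration, instead of calling model.generate():
    # the KV cache (past_key_values) is reused across steps, so after the first
    # forward pass only the most recent token has to be fed to the model, and each
    # sampled token can be printed as soon as it is decoded.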
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=generated_ids[:, -1:] if past_key_values is not None else generated_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )

        logits = outputs.logits
        past_key_values = outputs.past_key_values

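        # Sample the next token with temperature 0.7 and nucleus (top-p) sampling
        # with p = 0.9: logits are divided by the temperature, then only the
        # smallest set of tokens whose cumulative probability reaches 0.9 is kept.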
        next_token_logits = logits[:, -1, :]
        next_token_logits = next_token_logits / 0.7

        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > 0.9
        # Shift the mask right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order and filter.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_logits[indices_to_remove] = -float('Inf')

        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

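        # Re-decode the whole sequence and print only the not-yet-written characters.
        # (Decoding the full sequence each step avoids the artifacts that decoding
        # single tokens in isolation can produce with byte-level BPE tokenizers.)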
        current_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        if len(current_text) > input_length:
            new_text = current_text[input_length:]

            new_chars = new_text[len(response_text):]
            sys.stdout.write(new_chars)
            sys.stdout.flush()
            response_text = new_text

        # Brief pause between tokens to pace the streamed output.
        time.sleep(0.01)

        if next_token[0, 0].item() == tokenizer.eos_token_id:
            break

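        # If a stop sequence has appeared in the streamed text, truncate the response
        # at that point and return early. Characters past the stop point may already
        # have been echoed to the terminal by the streaming writes above.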
        for stop_seq in stop_sequences:
            if stop_seq in response_text:
                stop_idx = response_text.find(stop_seq)
                if stop_idx > 0:
                    was_truncated = True
                    response_text = response_text[:stop_idx]
                    sys.stdout.write("\n")
                    sys.stdout.flush()
                    return response_text, was_truncated

    return response_text, was_truncated


def main():
    parser = argparse.ArgumentParser(description="Compare base and fine-tuned Phi-2 models")
    parser.add_argument("--base-only", action="store_true", help="Use only the base model")
    parser.add_argument("--finetuned-only", action="store_true", help="Use only the fine-tuned model")
    parser.add_argument("--adapter-path", type=str, default="./phi2-grpo-qlora-final",
                        help="Path to the fine-tuned adapter")
    args = parser.parse_args()

    base_model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # The tokenizer has no pad token by default; reuse the EOS token.
    tokenizer.pad_token = tokenizer.eos_token

    models = {}

    if not args.finetuned_only:
        print("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        models["Base Phi-2"] = base_model

    if not args.base_only:
        print(f"Loading fine-tuned model from {args.adapter_path}...")

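        # A second copy of the base weights is loaded here so the adapter is
        # attached to its own instance rather than to the model registered as
        # "Base Phi-2" above.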
        base_model_for_ft = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        finetuned_model = PeftModel.from_pretrained(base_model_for_ft, args.adapter_path)
        models["Fine-tuned Phi-2"] = finetuned_model

print("\n" + "="*50) |
|
print("Interactive Phi-2 Model Comparison (Streaming Mode)") |
|
print("Type 'exit' to quit") |
|
print("="*50 + "\n") |
|
|
|
while True: |
|
|
|
user_prompt = input("\nEnter your prompt: ") |
|
if user_prompt.lower() == 'exit': |
|
break |
|
|
|
print("\n" + "-"*50) |
|
|
|
|
|
for model_name, model in models.items(): |
|
print(f"\n{model_name} response:") |
|
response, was_truncated = stream_response(model, tokenizer, user_prompt) |
|
if was_truncated: |
|
print("\n[Note: Response was truncated at a natural stopping point]") |
|
print("\n" + "-"*30) |
|
|
|
print("-"*50) |
|
|
|

if __name__ == "__main__":
    main()
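
# Example usage (assuming this file is saved as compare_phi2.py and the adapter
# was written to ./phi2-grpo-qlora-final, the default --adapter-path):
#
#   python compare_phi2.py                    # compare base vs. fine-tuned
#   python compare_phi2.py --base-only        # base model only
#   python compare_phi2.py --finetuned-only --adapter-path ./phi2-grpo-qlora-final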