File size: 5,269 Bytes
68a2fd4
b507cf2
e6e8f08
b507cf2
e6e8f08
 
68a2fd4
e6e8f08
 
 
 
 
 
68a2fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b507cf2
 
 
 
68a2fd4
 
 
 
 
 
 
 
 
 
e6e8f08
b507cf2
 
 
 
 
 
 
 
 
e6e8f08
 
 
 
b507cf2
e6e8f08
 
 
b507cf2
 
 
 
 
 
 
 
 
 
 
 
 
68a2fd4
e6e8f08
 
 
 
b507cf2
 
 
 
e6e8f08
 
b507cf2
ad4670f
ace32a8
ad4670f
e6e8f08
 
 
ad4670f
 
e6e8f08
 
ace32a8
e6e8f08
 
 
b507cf2
e6e8f08
 
b507cf2
e6e8f08
 
 
68a2fd4
 
e6e8f08
 
68a2fd4
 
 
e6e8f08
ad4670f
e6e8f08
 
b507cf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
from collections import defaultdict

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer

# Process-wide settings consumed by third-party libraries, set before any of
# them start work.
os.environ.update(
    {
        # Project name for Weights & Biases; only relevant if W&B reporting
        # is enabled in the training arguments.
        "WANDB_PROJECT": "phi2-grpo-finetuning",
        # Turn off the HF tokenizers library's internal parallelism.
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Fetch the "train" split of the OpenAssistant oasst1 dataset from the Hub.
raw_data = load_dataset(path="OpenAssistant/oasst1", split="train")

# Preprocess the dataset using logic from preprocess.py
# Group messages by conversation_id
conversations = defaultdict(list)
for item in raw_data:
    conversations[item["message_tree_id"]].append(item)

# Prepare preference pairs
pairs = []
for tree_id, msgs in conversations.items():
    prompt = next((m for m in msgs if m["role"] == "prompter" and m["parent_id"] is None), None)
    if not prompt:
        continue
    
    # Find direct replies to the prompt
    replies = [m for m in msgs if m["parent_id"] == prompt["message_id"]]
    
    # If we don't have ranking info or not enough replies, try to use other heuristics
    if len([r for r in replies if r.get("ranking")]) < 2:
        # If we have at least 2 replies, use them based on likes or other metrics
        if len(replies) >= 2:
            # Sort by likes if available, otherwise just take any two
            if all("like_count" in r for r in replies):
                ranked = sorted(replies, key=lambda x: x.get("like_count", 0), reverse=True)
            else:
                ranked = replies[:2]  # Just take the first two
            
            chosen = ranked[0]["text"]
            rejected = ranked[-1]["text"]
            
            pairs.append({
                "prompt": prompt["text"],
                "chosen": chosen,
                "rejected": rejected
            })
        continue
    
    # Original logic for replies with ranking
    ranked = sorted(replies, key=lambda x: x["ranking"])
    chosen = ranked[0]["text"]
    rejected = ranked[-1]["text"]

    pairs.append({
        "prompt": prompt["text"],
        "chosen": chosen,
        "rejected": rejected
    })

# Convert the pairs to a Hugging Face Dataset. The "prompt" column feeds
# generation; the other columns are carried along for the reward function.
# Upper bound on training examples to keep the run short; None uses every pair.
MAX_TRAIN_EXAMPLES = 1000

preference_dataset = Dataset.from_list(pairs)

if MAX_TRAIN_EXAMPLES is not None and len(preference_dataset) > MAX_TRAIN_EXAMPLES:
    preference_dataset = preference_dataset.select(range(MAX_TRAIN_EXAMPLES))

print(f"Created {len(preference_dataset)} preference pairs for GRPO")

if len(preference_dataset) == 0:
    # Nothing to train on. Stop here with a clear message instead of letting
    # the trainer fail later on an empty dataset that has no columns at all.
    raise RuntimeError(
        "No preference pairs were created. Check the dataset structure."
    )

# Debug: show the first pair, each field truncated to 100 characters.
sample = preference_dataset[0]
print("\nSample preference pair:")
for field in ("prompt", "chosen", "rejected"):
    print(f"{field.capitalize()}: {sample[field][:100]}...")

# Base model checkpoint on the Hugging Face Hub.
model_name = "microsoft/phi-2"

# Load weights in 4-bit NF4 with nested ("double") quantization of the
# quantization constants; computation runs in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" lets the loader decide device placement of the layers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)

# Prepare the 4-bit base model for adapter training (standard QLoRA step).
# This upcasts the non-quantized weights, enables gradient checkpointing, and
# makes the embedding outputs require grad, so gradients can flow through
# checkpointed segments of the frozen base into the LoRA layers.
model = prepare_model_for_kbit_training(model)

# Configure LoRA. Phi-2's linear layers are q_proj/k_proj/v_proj/dense
# (attention) and fc1/fc2 (MLP). The Llama-style names o_proj/gate_proj/
# up_proj/down_proj do not exist in Phi-2 and were silently skipped, which
# left only q/k/v adapted.
peft_config = LoraConfig(
    r=16,                     # Rank of the low-rank update matrices
    lora_alpha=32,            # Scaling numerator (effective scale = alpha / r)
    lora_dropout=0.05,        # Dropout applied on the LoRA branch
    bias="none",              # Keep all bias terms frozen
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
)

# Wrap the model so that only the LoRA parameters are trainable.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Batched generation needs a pad token; reuse EOS (presumably the Phi-2
# tokenizer ships without a dedicated pad token). Left padding keeps every
# prompt flush against the tokens generated after it.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Reward function for GRPO: prefer responses near a target length.
def reward_func(completions, target_words=64, **kwargs):
    """Score each completion by how close its word count is to a target.

    The previous version returned the raw word count, which rewards ever
    longer output and drives GRPO toward maximum-length generations. That
    contradicted the stated goal of concise answers. This version peaks at
    ``target_words``:

        reward = 1 - |n_words - target_words| / target_words

    So it is 1.0 at the target, 0.0 for an empty answer or one twice the
    target length, and keeps decreasing (below 0) for longer answers.
    Long answers therefore still get a ranking signal.

    Args:
        completions: generated completions. Each is either a plain string, or
            (conversational format) a list of ``{"role", "content"}`` messages.
        target_words: preferred response length in whitespace-separated words;
            must be positive.
        **kwargs: extra per-sample columns the trainer passes through; unused.

    Returns:
        A list of float rewards, one per completion.

    Raises:
        ValueError: if ``target_words`` is not positive.
    """
    if target_words <= 0:
        raise ValueError(f"target_words must be positive, got {target_words!r}")

    rewards = []
    for completion in completions:
        if isinstance(completion, list):
            # Conversational completions: join the text of all messages.
            text = " ".join(message.get("content", "") for message in completion)
        else:
            text = completion
        n_words = len(text.split())
        rewards.append(1.0 - abs(n_words - target_words) / target_words)
    return rewards

# Configure GRPO training. A checkpoint is written every 10 steps and only the
# newest one is kept, so an interrupted run can be resumed (see below).
training_args = GRPOConfig(
    output_dir="phi2-grpo-qlora",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-6,
    logging_steps=10,
    save_steps=10,                    # Save every 10 steps
    save_total_limit=1,               # Keep only the most recent checkpoint
    fp16=True,
    remove_unused_columns=False,      # Pass extra dataset columns to reward_func
    report_to="none",
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    num_generations=2,                # Completions sampled per prompt (group size)
)

# Initialize the GRPO trainer. The script's tokenizer (EOS as pad, left
# padding) is handed over as processing_class at construction time. The
# trainer sets up its processing class in __init__, so assigning
# trainer.tokenizer afterwards does not reliably take effect.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=preference_dataset,
    reward_funcs=reward_func,
    processing_class=tokenizer,
)

# Resume only when a checkpoint actually exists. resume_from_checkpoint=True
# raises ValueError on a fresh output directory, so the first run would crash.
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
if last_checkpoint is not None:
    print(f"Resuming from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)

# Save the final model (LoRA adapter weights for a PEFT-wrapped model).
trainer.save_model("phi2-grpo-qlora-final")