Ubuntu committed on
Commit e6e8f08 · 0 Parent(s)
Files changed (2)
  1. GRPO.py +103 -0
  2. preprocess.py +0 -0
GRPO.py ADDED
@@ -0,0 +1,103 @@
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from trl import GRPOConfig, GRPOTrainer
+ import torch
+ import os
+
+ # Set environment variables for better logging
+ os.environ["WANDB_PROJECT"] = "phi2-grpo-finetuning"
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # Load the OpenAssistant dataset
+ dataset = load_dataset("trl-internal-testing/oasst_preference_dataset", split="train")
+
+ # Load model and tokenizer
+ model_name = "microsoft/phi-2"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
+ )
+
+ # Configure tokenizer for generation (phi-2 has no pad token by default)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ # Process the dataset to extract prompts (GRPO generates its own completions;
+ # the reference response is kept only as an extra column)
+ def preprocess_function(examples):
+     # For OpenAssistant, we need to format the conversations properly.
+     # This is a simplified version - you may need to adjust it to the exact structure.
+     prompts = []
+     responses = []
+
+     for message in examples["messages"]:
+         if len(message) >= 2:  # Ensure there is at least a prompt and a response
+             prompts.append(message[0]["content"])
+             responses.append(message[1]["content"])
+
+     return {"prompt": prompts, "response": responses}
+
+ # Process the dataset
+ processed_dataset = dataset.map(
+     preprocess_function,
+     batched=True,
+     remove_columns=dataset.column_names
+ )
+
+ # Define a reward function that favors structured, reasonably detailed responses.
+ # GRPOTrainer passes the generated completions (plus any extra dataset columns)
+ # as keyword arguments and expects one float per completion.
+ def reward_function(completions, **kwargs):
+     rewards = []
+     for completion in completions:
+         # 1. Length-based component (peaks at ~300 chars, falls to 0 at 0 and 600 chars)
+         length_score = min(1.0, max(0.0, 1.0 - abs(len(completion) - 300) / 300))
+
+         # 2. Quality heuristics (simple examples)
+         has_structure = 0.5 if any(marker in completion for marker in ["First", "Second", "Finally", "In conclusion"]) else 0.0
+         is_detailed = 0.5 if len(completion) > 200 else 0.0
+
+         # Combine reward components
+         rewards.append(length_score + has_structure + is_detailed)
+
+     return rewards
+
+ # Configure GRPO training
+ # (no eval_dataset is built above, so no evaluation strategy is configured)
+ training_args = GRPOConfig(
+     output_dir="phi2-grpo-openassistant",
+     num_train_epochs=3,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     gradient_checkpointing=True,
+     learning_rate=5e-6,
+     max_prompt_length=512,  # GRPOConfig uses max_prompt_length/max_completion_length, not max_length
+     logging_steps=10,
+     save_steps=100,
+     bf16=True,  # match the bfloat16 weights loaded above
+     remove_unused_columns=False,
+     report_to="wandb",
+     optim="adamw_torch",
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+     # Depending on the TRL version, num_generations may also need to divide the effective batch size.
+ )
+
+ # Initialize the GRPO trainer
+ # (GRPOTrainer takes processing_class rather than tokenizer, and has no SFT-style packing option)
+ trainer = GRPOTrainer(
+     model=model,
+     processing_class=tokenizer,
+     args=training_args,
+     train_dataset=processed_dataset,
+     reward_funcs=reward_function,
+ )
+
+ # Start training
+ trainer.train()
+
+ # Save the final model
+ trainer.save_model("phi2-grpo-openassistant-final")
preprocess.py ADDED
File without changes
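
Not part of the commit: since the reward heuristic in GRPO.py is the piece most likely to need tuning, a quick standalone sanity check can help before launching a full training run. The sketch below copies the scoring logic and prints the reward for two made-up completions; the sample strings are purely illustrative.

# Hypothetical sanity check for the reward heuristic in GRPO.py (not part of the commit)
def reward_function(completions, **kwargs):
    rewards = []
    for completion in completions:
        length_score = min(1.0, max(0.0, 1.0 - abs(len(completion) - 300) / 300))
        has_structure = 0.5 if any(m in completion for m in ["First", "Second", "Finally", "In conclusion"]) else 0.0
        is_detailed = 0.5 if len(completion) > 200 else 0.0
        rewards.append(length_score + has_structure + is_detailed)
    return rewards

samples = [
    "First, check the logs. Second, restart the service. " * 5,  # ~265 chars, structured and detailed
    "Yes.",  # very short, no structure
]
for text, reward in zip(samples, reward_function(samples)):
    print(f"{reward:.2f} <- {text[:40]!r}")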