Ubuntu committed
Commit 68a2fd4 · 1 Parent(s): e6e8f08
Files changed (2)
  1. GRPO.py +76 -52
  2. preprocess.py +40 -0
GRPO.py CHANGED
@@ -1,15 +1,77 @@
- from datasets import load_dataset
+ from datasets import load_dataset, Dataset
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from trl import GRPOConfig, GRPOTrainer
  import torch
  import os
+ from collections import defaultdict

  # Set environment variables for better logging
  os.environ["WANDB_PROJECT"] = "phi2-grpo-finetuning"
  os.environ["TOKENIZERS_PARALLELISM"] = "false"

  # Load the OpenAssistant dataset
- dataset = load_dataset("trl-internal-testing/oasst_preference_dataset", split="train")
+ raw_data = load_dataset("OpenAssistant/oasst1", split="train")
+
+ # Preprocess the dataset using logic from preprocess.py
+ # Group messages by conversation_id
+ conversations = defaultdict(list)
+ for item in raw_data:
+     conversations[item["message_tree_id"]].append(item)
+
+ # Prepare preference pairs
+ pairs = []
+ for tree_id, msgs in conversations.items():
+     prompt = next((m for m in msgs if m["role"] == "prompter" and m["parent_id"] is None), None)
+     if not prompt:
+         continue
+
+     # Find direct replies to the prompt
+     replies = [m for m in msgs if m["parent_id"] == prompt["message_id"]]
+
+     # If we don't have ranking info or not enough replies, try to use other heuristics
+     if len([r for r in replies if r.get("ranking")]) < 2:
+         # If we have at least 2 replies, use them based on likes or other metrics
+         if len(replies) >= 2:
+             # Sort by likes if available, otherwise just take any two
+             if all("like_count" in r for r in replies):
+                 ranked = sorted(replies, key=lambda x: x.get("like_count", 0), reverse=True)
+             else:
+                 ranked = replies[:2]  # Just take the first two
+
+             chosen = ranked[0]["text"]
+             rejected = ranked[-1]["text"]
+
+             pairs.append({
+                 "prompt": prompt["text"],
+                 "chosen": chosen,
+                 "rejected": rejected
+             })
+         continue
+
+     # Original logic for replies with ranking
+     ranked = sorted(replies, key=lambda x: x["ranking"])
+     chosen = ranked[0]["text"]
+     rejected = ranked[-1]["text"]
+
+     pairs.append({
+         "prompt": prompt["text"],
+         "chosen": chosen,
+         "rejected": rejected
+     })
+
+ # Convert to Hugging Face dataset format for preference learning
+ preference_dataset = Dataset.from_list(pairs)
+
+ print(f"Created {len(preference_dataset)} preference pairs for GRPO")
+
+ # Debug: Print a sample pair if available
+ if len(preference_dataset) > 0:
+     print("\nSample preference pair:")
+     print(f"Prompt: {preference_dataset[0]['prompt'][:100]}...")
+     print(f"Chosen: {preference_dataset[0]['chosen'][:100]}...")
+     print(f"Rejected: {preference_dataset[0]['rejected'][:100]}...")
+ else:
+     print("WARNING: No preference pairs were created. Check the dataset structure.")

  # Load model and tokenizer
  model_name = "microsoft/phi-2"
@@ -20,64 +82,25 @@ model = AutoModelForCausalLM.from_pretrained(
      device_map="auto"
  )

+ # Define a reward function that rewards helpful, concise responses
+ # and penalizes responses similar to rejected ones
+ def reward_func(completions, **kwargs):
+     return [len(c.split()) for c in completions]  # reward by word count
+
  # Configure tokenizer for chat format
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.padding_side = "left"

- # Process the dataset to create prompt-response pairs
- def preprocess_function(examples):
-     # For OpenAssistant, we need to format the conversations properly
-     # This is a simplified version - you may need to adjust based on the exact structure
-     prompts = []
-     responses = []
-
-     for message in examples["messages"]:
-         if len(message) >= 2:  # Ensure there's at least a prompt and response
-             prompt = message[0]["content"]
-             response = message[1]["content"]
-             prompts.append(prompt)
-             responses.append(response)
-
-     return {"prompt": prompts, "response": responses}
-
- # Process the dataset
- processed_dataset = dataset.map(
-     preprocess_function,
-     batched=True,
-     remove_columns=dataset.column_names
- )
-
- # Define a reward function that rewards helpful, concise responses
- def reward_function(responses, prompts=None, **kwargs):
-     rewards = []
-     for response in responses:
-         # Example reward criteria:
-         # 1. Length-based component (prefer responses between 100-500 chars)
-         length_score = min(1.0, max(0.0, 1.0 - abs(len(response) - 300) / 300))
-
-         # 2. Quality heuristics (simple examples)
-         has_structure = 0.5 if any(marker in response for marker in ["First", "Second", "Finally", "In conclusion"]) else 0.0
-         is_detailed = 0.5 if len(response) > 200 else 0.0
-
-         # Combine reward components
-         reward = length_score + has_structure + is_detailed
-         rewards.append(reward)
-
-     return rewards
-
  # Configure GRPO training
  training_args = GRPOConfig(
      output_dir="phi2-grpo-openassistant",
      num_train_epochs=3,
-     per_device_train_batch_size=4,
+     per_device_train_batch_size=8,
      gradient_accumulation_steps=4,
      gradient_checkpointing=True,
      learning_rate=5e-6,
-     max_length=512,
      logging_steps=10,
      save_steps=100,
-     eval_steps=100,
-     evaluation_strategy="steps",
      fp16=True,
      remove_unused_columns=False,
      report_to="wandb",
@@ -86,16 +109,17 @@ training_args = GRPOConfig(
      warmup_ratio=0.1,
  )

- # Initialize the GRPO trainer
+ # Initialize the GRPO trainer with preference dataset
  trainer = GRPOTrainer(
      model=model,
-     tokenizer=tokenizer,
      args=training_args,
-     train_dataset=processed_dataset,
-     reward_funcs=reward_function,
-     packing=False,
+     train_dataset=preference_dataset,
+     reward_funcs=reward_func,
  )

+ # Set the tokenizer on the trainer after initialization
+ trainer.tokenizer = tokenizer
+
  # Start training
  trainer.train()

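The new reward_func above scores completions by word count alone, even though its comment promises to penalize responses that resemble the rejected replies. A sketch of a reward that actually uses the preference pairs is given below; it assumes (per recent TRL releases, worth verifying against the installed version) that GRPOTrainer forwards extra dataset columns such as "chosen" and "rejected" to reward functions as keyword-argument lists aligned with completions. The function name and the token-overlap heuristic are illustrative, not part of the commit.

# Illustrative sketch (not part of the commit): reward completions for
# resembling the chosen reply and penalize resemblance to the rejected one.
# Assumes GRPOTrainer passes the dataset's "chosen"/"rejected" columns as
# per-sample keyword-argument lists; check the behavior of the TRL version in use.
def preference_overlap_reward(completions, chosen=None, rejected=None, **kwargs):
    def overlap(a, b):
        # Crude Jaccard similarity over lowercased whitespace tokens, in [0, 1].
        ta, tb = set(a.lower().split()), set(b.lower().split())
        return len(ta & tb) / max(1, len(ta | tb))

    rewards = []
    for i, completion in enumerate(completions):
        score = 0.0
        if chosen is not None:
            score += overlap(completion, chosen[i])    # pull toward the preferred reply
        if rejected is not None:
            score -= overlap(completion, rejected[i])  # push away from the rejected reply
        rewards.append(score)
    return rewards

Such a function could be passed as reward_funcs=preference_overlap_reward, or combined with the word-count reward by passing a list of reward functions.
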
preprocess.py CHANGED
@@ -0,0 +1,40 @@
+ from datasets import load_dataset, Dataset
+ from collections import defaultdict
+
+ # Load dataset
+ raw_data = load_dataset("OpenAssistant/oasst1", split="train")
+
+ # Group messages by conversation_id
+ conversations = defaultdict(list)
+ for item in raw_data:
+     conversations[item["message_tree_id"]].append(item)
+
+ # Prepare preference pairs
+ pairs = []
+ for tree_id, msgs in conversations.items():
+     prompt = next((m for m in msgs if m["role"] == "prompter" and m["parent_id"] is None), None)
+     if not prompt:
+         continue
+
+     # Find direct replies with ranking
+     replies = [m for m in msgs if m["parent_id"] == prompt["message_id"] and m.get("ranking")]
+
+     if len(replies) < 2:
+         continue
+
+     # Sort replies by rank
+     ranked = sorted(replies, key=lambda x: x["ranking"])
+
+     # Create one preference pair (you can create more pairs per prompt if you want)
+     chosen = ranked[0]["text"]
+     rejected = ranked[-1]["text"]
+
+     pairs.append({
+         "prompt": prompt["text"],
+         "chosen": chosen,
+         "rejected": rejected
+     })
+
+ # Convert to Hugging Face dataset format
+ preference_dataset = Dataset.from_list(pairs)
+ preference_dataset.save_to_disk("oasst_preference_for_grpo")
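
Since preprocess.py persists the pairs with save_to_disk, GRPO.py could load the prepared dataset rather than repeating the grouping logic inline. A minimal usage sketch, assuming the script has already been run from the same working directory:

from datasets import load_from_disk

# Reuse the pairs written by preprocess.py ("oasst_preference_for_grpo" is the
# path used in the commit) instead of rebuilding them at training time.
preference_dataset = load_from_disk("oasst_preference_for_grpo")
print(preference_dataset)  # expected columns: prompt, chosen, rejected

One caveat worth checking: both scripts filter and sort replies on a "ranking" key, while the oasst1 dataset card documents the per-reply rank under a column named rank, so if the scripts report zero preference pairs the column name is the first thing to verify.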