# training.py — supervised fine-tuning script for the ddllama custom model (xuan-luo, Hugging Face custom_code repo)
import sys
import logging
import datasets
from datasets import load_dataset
import torch
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from typing import Dict, List
logger = logging.getLogger(__name__)
###################
# Hyper-parameters
###################
training_config = {
    "bf16": True,
    "do_eval": False,
    "learning_rate": 1e-04,
    "log_level": "info",
    "logging_steps": 20,
    "logging_strategy": "steps",
    "lr_scheduler_type": "cosine",
    "num_train_epochs": 3,
    "max_steps": -1,
    "output_dir": "./tulu_sft",
    "overwrite_output_dir": True,
    "per_device_eval_batch_size": 4,
    "per_device_train_batch_size": 4,
    "remove_unused_columns": True,
    "save_steps": 1000,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    "gradient_accumulation_steps": 4,  # You may use a larger batch size, but then normalize the loss by gradient_accumulation_steps manually in modeling_ddllama.py.
    "warmup_ratio": 0.03,
    "ddp_find_unused_parameters": True,
}
train_conf = TrainingArguments(**training_config)
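# Note on batch sizing: with these settings the effective batch size is
# per_device_train_batch_size (4) x gradient_accumulation_steps (4) = 16 sequences per process,
# multiplied by the number of data-parallel processes when launched with torchrun/accelerate.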
###############
# Setup logging
###############
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = train_conf.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log a small summary on each process
logger.warning(
    f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
    + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bit training: {train_conf.fp16}"
)
logger.info(f"Training/evaluation parameters {train_conf}")
################
# Model Loading
################
checkpoint_path = "./"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=None,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
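# The Llama 3 tokenizer ships without a dedicated pad token, so one of its reserved
# special tokens (id 128002) is reused for padding; right padding is the usual choice for training.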
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.model_max_length = 2048
tokenizer.pad_token = "<|reserved_special_token_0|>"
tokenizer.pad_token_id = 128002
tokenizer.padding_side = 'right'
##################
# Data Processing
##################
# Render each conversation into a single text string using the tokenizer's chat template.
def apply_chat_template(
    example,
    tokenizer,
):
    messages = example["messages"]
    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False)
    return example
raw_dataset = load_dataset("allenai/tulu-v2-sft-mixture")
train_dataset = raw_dataset["train"]
column_names = list(train_dataset.features)
processed_dataset = train_dataset.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    num_proc=64,
    remove_columns=column_names,
    desc="Applying chat template to train_sft",
)
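# Optional sanity check: log one rendered example to confirm the chat template
# produced the expected "text" field before training starts.
logger.info("Sample formatted example:\n%s", processed_dataset[0]["text"][:500])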
###########
# Freeze Transformer
###########
# Freeze all parameters, then unfreeze only the router parameters.
for param in model.parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if 'router' in name.lower():
        param.requires_grad = True
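# Optional sanity check: report how many parameters remain trainable after freezing
# (only the router weights should contribute here).
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,}")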
###########
# Training
###########
trainer = SFTTrainer(
    model=model,
    args=train_conf,
    peft_config=None,
    train_dataset=processed_dataset,
    eval_dataset=None,
    max_seq_length=2048,
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=True,
)
train_result = trainer.train()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
#############
# Save model
#############
trainer.save_model(train_conf.output_dir)
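# Optional: also save the tokenizer so the output directory can be loaded directly
# for inference (a common convenience, not strictly required by the script above).
tokenizer.save_pretrained(train_conf.output_dir)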