import sys
import logging

import datasets
from datasets import load_dataset
import torch
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

logger = logging.getLogger(__name__)

###################
# Hyper-parameters
###################
training_config = {
    "bf16": True,
    "do_eval": False,
    "learning_rate": 1e-04,
    "log_level": "info",
    "logging_steps": 20,
    "logging_strategy": "steps",
    "lr_scheduler_type": "cosine",
    "num_train_epochs": 3,
    "max_steps": -1,
    "output_dir": "./tulu_sft",
    "overwrite_output_dir": True,
    "per_device_eval_batch_size": 4,
    "per_device_train_batch_size": 4,
    "remove_unused_columns": True,
    "save_steps": 1000,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    # You may use a bigger batch size and manually normalize the loss with
    # gradient_accumulation_steps in modeling_ddllama.py
    "gradient_accumulation_steps": 4,
    "warmup_ratio": 0.03,
    "ddp_find_unused_parameters": True,
}
train_conf = TrainingArguments(**training_config)

###############
# Setup logging
###############
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = train_conf.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log a small summary on each process
logger.warning(
    f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
    + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bit training: {train_conf.fp16}"
)
logger.info(f"Training/evaluation parameters {train_conf}")

################
# Model Loading
################
checkpoint_path = "./"
model_kwargs = dict(
    use_cache=False,  # KV caching is incompatible with gradient checkpointing
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=None,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.model_max_length = 2048
# Llama 3 ships without a pad token; reuse a reserved special token so padding
# never collides with real text.
tokenizer.pad_token = "<|reserved_special_token_0|>"
tokenizer.pad_token_id = 128002
tokenizer.padding_side = "right"

##################
# Data Processing
##################
def apply_chat_template(example, tokenizer):
    """Render the list of chat messages into a single training string."""
    messages = example["messages"]
    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return example

raw_dataset = load_dataset("allenai/tulu-v2-sft-mixture")
train_dataset = raw_dataset["train"]
column_names = list(train_dataset.features)

processed_dataset = train_dataset.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    num_proc=64,
    remove_columns=column_names,
    desc="Applying chat template to train_sft",
)

#####################
# Freeze Transformer
#####################
# Freeze every weight, then unfreeze only the router parameters so training
# updates the routers alone.
for param in model.parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if "router" in name.lower():
        param.requires_grad = True

###########
# Training
###########
trainer = SFTTrainer(
    model=model,
    args=train_conf,
    peft_config=None,
    train_dataset=processed_dataset,
    eval_dataset=None,
    max_seq_length=2048,
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=True,
)
train_result = trainer.train()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

############
# Save model
############
trainer.save_model(train_conf.output_dir)
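
# Example launch (an assumption, not part of the original script): the config
# sets ddp_find_unused_parameters, so the script expects distributed
# data-parallel execution. On a single node with a hypothetical 8 GPUs and
# this file saved as train_sft.py, one would run:
#
#   torchrun --nproc_per_node=8 train_sft.py
#
# The effective global batch size would then be
# per_device_train_batch_size * gradient_accumulation_steps * world_size
# = 4 * 4 * 8 = 128.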