---
license: mit
language: en
base_model: meta-llama/Meta-Llama-3-8B-Instruct
tags:
- math
- differential-equations
- dpo
- lbt
- instruction-tuned
---

# LMT-tuning: Llama-3-8B Fine-tuned for Differential Equations

This model is a fine-tuned version of `meta-llama/Meta-Llama-3-8B-Instruct`, specialized for solving university-level differential equation problems.

The model was trained using the **Learning by Teaching (LbT)** paradigm combined with **Direct Preference Optimization (DPO)**. This approach aims to improve a "teacher" model's reasoning by having it teach a "student" model and learning from the student's performance.

## Model Description

The core idea of the training process was to build a high-quality preference dataset in which the "better" response was not only more likely to be correct but also more effective as teaching material. The pipeline involved:

1. **Data Preparation:** A raw corpus of ~1500 differential equation problems was flattened and split into a training set (~1200 problems) and a test set (~300 problems).
2. **Teacher Generation:** The base Llama-3-8B-Instruct model generated 32 step-by-step solutions (rationales) for each of the ~1200 training problems.
3. **Student Examination (LbT Scoring):** For each of the ~39,000 generated rationales, a "student" model (also Llama-3-8B) was taught using that rationale as a one-shot example. The student then took an exam built from similar problems, and its performance yielded an "LbT score" for the rationale.
4. **Preference Creation:** Rationales were scored on a combination of correctness and LbT score. High-scoring rationales were paired with low-scoring ones to create a preference dataset of `(prompt, chosen, rejected)` triplets.
5. **DPO Fine-tuning:** The base Llama-3-8B-Instruct model was fine-tuned on this preference dataset using `trl`'s `DPOTrainer` and QLoRA.
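Step 4 above can be sketched in plain Python as follows. The card does not say how correctness and LbT score were combined or how many pairs were drawn per problem, so the `Rationale` class, the `combined_score` rule (a weighted sum with a hypothetical `correctness_weight`), and `build_preference_pairs` with its `num_pairs` parameter are illustrative assumptions, not the author's code. The output dictionaries use the `prompt` / `chosen` / `rejected` keys of the standard preference format that `trl`'s `DPOTrainer` consumes.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Rationale:
    text: str
    is_correct: bool  # final answer matches the reference solution
    lbt_score: float  # student's exam score after one-shot teaching, assumed in [0, 1]


def combined_score(r: Rationale, correctness_weight: float = 2.0) -> float:
    # Hypothetical rule: with lbt_score in [0, 1] and weight 2.0, every correct
    # rationale outranks every incorrect one; the LbT score orders within each group.
    return correctness_weight * float(r.is_correct) + r.lbt_score


def build_preference_pairs(
    prompt: str, rationales: List[Rationale], num_pairs: int = 4
) -> List[Dict[str, str]]:
    """Pair the best-scoring rationales with the worst-scoring ones."""
    ranked = sorted(rationales, key=combined_score, reverse=True)
    pairs: List[Dict[str, str]] = []
    # Walk inward from both ends of the ranking so no rationale is used twice.
    for i in range(min(num_pairs, len(ranked) // 2)):
        chosen, rejected = ranked[i], ranked[-(i + 1)]
        if combined_score(chosen) <= combined_score(rejected):
            break  # remaining pairs carry no preference signal
        pairs.append({"prompt": prompt, "chosen": chosen.text, "rejected": rejected.text})
    return pairs


# Usage with made-up rationales for a single training problem
rationales = [
    Rationale("rationale A ... [[Final Answer]]: y = 3e^{2x}", True, 0.9),
    Rationale("rationale B ... [[Final Answer]]: y = 3e^{2x}", True, 0.4),
    Rationale("rationale C ... [[Final Answer]]: y = 2e^{3x}", False, 0.7),
    Rationale("rationale D ... [[Final Answer]]: y = 3e^{-2x}", False, 0.1),
]
pairs = build_preference_pairs("Solve y' - 2y = 0 with y(0) = 3.", rationales, num_pairs=2)
for p in pairs:
    print(p["chosen"][:11], ">", p["rejected"][:11])
# rationale A > rationale D
# rationale B > rationale C
```

Pairing the extremes of the ranking maximizes the score gap within each pair, which gives DPO a clear preference signal; a real pipeline might instead sample pairs or require a minimum score margin.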
## Intended Use

This model is primarily intended for:

- **Solving differential equation problems:** Providing step-by-step reasoning and a final answer.
- **Educational purposes:** Serving as a tool for students to check their work and understand problem-solving steps.
- **Research:** Acting as a baseline for further fine-tuning on specialized mathematical domains.

**Note:** This is a specialist model. While it has been fine-tuned for differential equations, its performance on general-purpose chat and other reasoning tasks may have degraded.

## How to Use

You can use this model with the `transformers` `pipeline`. For best results, format the prompt with the Llama 3 chat template exactly as shown below, since this is the format the model was trained on.

```python
import torch
from transformers import pipeline

# Load the model and tokenizer
pipe = pipeline(
    "text-generation",
    model="Sandesh-Zenteiq/LMT-tuning",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Your differential equations problem
problem = "Solve the initial value problem: y' - 2y = 0, with y(0) = 3."

# The exact instruction text the model was trained on
instruction_text = (
    "Your task is to answer the last question below. "
    "Give step by step reasoning before you answer. "
    "When you're ready to answer, please wrap your answer and conclude using the format\n"
    "'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n"
)
exam_template = (
    "[[Question]]:\n{question}\n\n"
    "[[Solution]]:\nLet's think step by step.\n\n"
)

# Format the prompt using the Llama 3 chat template
prompt = (
    f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{instruction_text}{exam_template.format(question=problem)}"
    f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Generate the response with greedy (deterministic) decoding.
# For sampled answers, pass do_sample=True, temperature=0.7, top_p=0.9 instead.
response = pipe(
    prompt,
    max_new_tokens=1024,
    do_sample=False,
    return_full_text=False,  # return only the newly generated text, not the prompt
)

# The pipeline returns a list with one dict per generated sequence
assistant_response = response[0]["generated_text"]
print(assistant_response)
```

## Training Details

- **Base Model:** `meta-llama/Meta-Llama-3-8B-Instruct`
- **Framework:** `trl.DPOTrainer` with QLoRA
- **Hardware:** NVIDIA A6000 / H200-class GPUs
- **Key Hyperparameters:**
  - `learning_rate`: 2e-5
  - `num_epochs`: 1
  - `lora_r`: 128
  - `lora_alpha`: 256
  - `gradient_accumulation_steps`: 16

## Evaluation

The model was evaluated on a held-out test set of 305 differential equation problems that were not seen during training. The metric is Pass@1 accuracy.

| Model | Pass@1 Accuracy |
|---|---|
| `meta-llama/Meta-Llama-3-8B-Instruct` (Base) | 10.16% |
| LMT-tuning (This Model) | 16.07% |

This represents a +5.90 percentage-point absolute improvement and a ~58% relative improvement on this specialized task.

Model fine-tuned by Sandesh-Zenteiq. The methodology is based on the paper "Can LLMs Learn by Teaching for Better Reasoning?"
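Computing Pass@1 requires pulling the answer out of the `[[Final Answer]]:` block that the prompt asks for; with one (greedy) generation per problem, Pass@1 reduces to the fraction of problems answered correctly. The sketch below shows one way to do this. The helpers `extract_final_answer`, `normalize`, and `pass_at_1` are hypothetical, and the normalization (strip `$` delimiters and whitespace, then compare strings exactly) is an assumption: the card does not say how answers were checked, and equivalent solutions written differently, such as `y = 3\exp(2x)`, would need symbolic comparison.

```python
import re
from typing import List, Optional

# Matches the answer block requested by the prompt:
#   '''\n[[Final Answer]]:\n$ANSWER$\n'''
_FINAL_ANSWER_RE = re.compile(r"\[\[Final Answer\]\]:\s*(.*?)\s*(?:'''|$)", re.DOTALL)


def extract_final_answer(generation: str) -> Optional[str]:
    """Return the text after the last [[Final Answer]]: marker, or None if absent."""
    matches = _FINAL_ANSWER_RE.findall(generation)
    return matches[-1] if matches else None


def normalize(answer: str) -> str:
    # Hypothetical normalization: drop surrounding $...$ and all whitespace.
    return re.sub(r"\s+", "", answer.strip().strip("$"))


def pass_at_1(generations: List[str], references: List[str]) -> float:
    """Fraction of problems whose single generated answer matches the reference."""
    if not references or len(generations) != len(references):
        raise ValueError("need exactly one generation per reference")
    correct = 0
    for gen, ref in zip(generations, references):
        pred = extract_final_answer(gen)
        if pred is not None and normalize(pred) == normalize(ref):
            correct += 1
    return correct / len(references)


# Usage with two made-up model outputs
generations = [
    "... so y = 3e^{2x}.\n'''\n[[Final Answer]]:\n$y = 3e^{2x}$\n'''",
    "... therefore\n'''\n[[Final Answer]]:\ny = 2e^{3x}\n'''",
]
references = ["y=3e^{2x}", "y = 3e^{3x}"]
print(pass_at_1(generations, references))  # 0.5
```

Taking the last marker in the output is a defensive choice in case the model restates the answer format mid-reasoning before committing to a final answer.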