# Multi-Token Prediction GPT-2 Model
This is a GPT-2 model enhanced with a Multi-Token Prediction (MTP) architecture, trained on the MetaMathQA dataset for mathematical reasoning tasks.
## Model Description
This model implements the Multi-Token Prediction approach from the paper "Better & Faster Large Language Models via Multi-token Prediction" by Meta AI. Key features:
- Shared Trunk Architecture: 9 shared transformer layers feed 4 parallel prediction heads (see the architecture sketch after this list)
- Multi-Token Prediction: Predicts 4 tokens simultaneously (t+1, t+2, t+3, t+4)
- Enhanced Speculative Decoding: Achieves up to 3x inference speedup
- Mathematical Reasoning: Fine-tuned specifically for mathematical problem solving
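The layout below is a minimal sketch of how such a shared-trunk / multi-head model can be assembled from Hugging Face's GPT-2 blocks. It is illustrative only: the class name `MultiTokenGPT2Sketch`, the choice to seed each head from a copy of a remaining GPT-2 block, and the forward signature are assumptions, not the released training code.

```python
import copy

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel


class MultiTokenGPT2Sketch(nn.Module):
    """Shared trunk of GPT-2 blocks feeding several parallel prediction heads.

    Hypothetical reconstruction for illustration; the released checkpoint uses
    its own MultiTokenGPT2 class from the training code.
    """

    def __init__(self, n_trunk_layers: int = 9, n_heads: int = 4):
        super().__init__()
        base = GPT2LMHeadModel.from_pretrained("gpt2")
        gpt2 = base.transformer
        self.wte, self.wpe, self.drop = gpt2.wte, gpt2.wpe, gpt2.drop
        self.trunk = gpt2.h[:n_trunk_layers]          # 9 shared blocks
        self.ln_f = gpt2.ln_f
        # One head per future offset (t+1 ... t+4); here each head is a copy of
        # the next GPT-2 block followed by the shared vocabulary projection.
        self.head_blocks = nn.ModuleList(
            [copy.deepcopy(gpt2.h[n_trunk_layers]) for _ in range(n_heads)]
        )
        self.lm_head = base.lm_head                   # weight-tied output layer

    def forward(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
        pos = torch.arange(input_ids.size(1), device=input_ids.device)
        hidden = self.drop(self.wte(input_ids) + self.wpe(pos))
        for block in self.trunk:                      # shared computation
            hidden = block(hidden)[0]
        # Each head refines the trunk output independently; head k predicts the
        # token k+1 positions ahead of every input position.
        return [self.lm_head(self.ln_f(block(hidden)[0])) for block in self.head_blocks]
```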
## Architecture Details
- Base Model: GPT-2 (124M parameters)
- Trunk Layers: 9 (shared processing)
- Prediction Heads: 4 parallel heads, one per future-token offset (training objective sketched after this list)
- Training Data: MetaMathQA dataset (500 samples)
- Training Epochs: 1 with gradient accumulation
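During training, each head's cross-entropy is computed against the input shifted by that head's offset, and the per-head losses are combined. The helper below is a minimal sketch of that objective; equal head weighting is an assumption, and `multi_token_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def multi_token_loss(head_logits: list[torch.Tensor], input_ids: torch.Tensor) -> torch.Tensor:
    """Average cross-entropy over the prediction heads.

    head_logits[k] has shape (batch, seq, vocab) and is trained to predict the
    token k+1 positions ahead, so its labels are input_ids shifted by k+1.
    """
    losses = []
    for k, logits in enumerate(head_logits):
        offset = k + 1
        pred = logits[:, :-offset, :]        # positions that still have a target
        target = input_ids[:, offset:]       # labels shifted by the head's offset
        losses.append(
            F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
        )
    return torch.stack(losses).mean()
```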
## Usage
```python
import torch
from transformers import GPT2Tokenizer

# Note: you'll need the custom MultiTokenGPT2 class from the training code.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load your model (custom loading required)
# model = MultiTokenGPT2.from_pretrained("Goldenwert/multitoken-gpt2-metamathqa")

prompt = "What is the derivative of f(x) = x^3?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Standard generation
generated = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=False,
)

# Fast speculative generation
generated_fast = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=True,
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))
```
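Because the checkpoint ships as a raw `pytorch_model.bin` rather than a standard `transformers` architecture, loading will likely look something like the sketch below. `MultiTokenGPT2` stands in for the class from the training code, and its state-dict keys must match that class; the download call itself is the standard `huggingface_hub` API.

```python
import torch
from huggingface_hub import hf_hub_download

# Fetch the raw weights from the Hub and load them into the custom class.
weights_path = hf_hub_download(
    repo_id="Goldenwert/multitoken-gpt2-metamathqa", filename="pytorch_model.bin"
)
model = MultiTokenGPT2()  # custom class defined in the training code
state_dict = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```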
## Performance
- Inference Speed: Up to 3x faster with speculative decoding (see the decoding sketch after this list)
- Memory Efficiency: Gradient checkpointing support
- Mathematical Tasks: Improved reasoning on math problems
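Conceptually, the self-speculative loop drafts several tokens from the extra heads in one forward pass and keeps only the prefix that the next-token head confirms. The function below is a simplified greedy sketch under the assumptions of the architecture sketch above (forward returns one logits tensor per head, batch size 1); it is not the model's actual `generate` implementation.

```python
import torch


@torch.no_grad()
def speculative_generate(model, input_ids: torch.Tensor, max_new_tokens: int = 50) -> torch.Tensor:
    """Greedy self-speculative decoding sketch for a multi-head model."""
    ids = input_ids
    while ids.size(1) - input_ids.size(1) < max_new_tokens:
        head_logits = model(ids)                 # list of (1, seq, vocab) tensors
        # Draft: head k proposes the token k+1 steps after the last position.
        draft = torch.stack(
            [logits[:, -1, :].argmax(dim=-1) for logits in head_logits], dim=1
        )                                        # (1, n_heads)
        candidate = torch.cat([ids, draft], dim=1)
        # Verify the drafted tokens with the next-token (t+1) head.
        verify_logits = model(candidate)[0]      # (1, seq + n_heads, vocab)
        n_accept = 1                             # the t+1 draft is the model's own greedy choice
        for k in range(1, draft.size(1)):
            pos = ids.size(1) + k - 1            # position whose t+1 prediction is draft[:, k]
            if verify_logits[:, pos, :].argmax(dim=-1).item() == draft[0, k].item():
                n_accept += 1
            else:
                break                            # stop at the first disagreement
        ids = torch.cat([ids, draft[:, :n_accept]], dim=1)
    return ids[:, : input_ids.size(1) + max_new_tokens]
```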
## Training Details
- Dataset: MetaMathQA (mathematical reasoning)
- Optimizer: AdamW with learning-rate warmup (a training-loop sketch follows this list)
- Learning Rate: 5e-5
- Batch Size: 4 (effective 32 with gradient accumulation)
- Hardware: GPU with FP16 precision
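Putting these hyperparameters together, a training loop in the spirit of this setup might look like the sketch below. It reuses the hypothetical `multi_token_loss` helper from earlier; the 10% warmup fraction and the use of `get_linear_schedule_with_warmup` are assumptions, since only "AdamW with warmup" is stated.

```python
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


def train_one_epoch(model, loader, total_updates, accum_steps=8, device="cuda"):
    """AdamW at 5e-5, linear warmup, micro-batch 4 accumulated to an effective
    batch of 32 (4 x 8), and fp16 autocast. total_updates = optimizer steps."""
    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_updates),   # assumed warmup fraction
        num_training_steps=total_updates,
    )
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        input_ids = batch["input_ids"].to(device)
        with torch.cuda.amp.autocast():              # fp16 mixed precision
            head_logits = model(input_ids)
            loss = multi_token_loss(head_logits, input_ids) / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:            # gradient accumulation
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
```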
## Research Context
This model builds on the paper "Better & Faster Large Language Models via Multi-token Prediction" (Meta AI), which demonstrates that predicting multiple future tokens simultaneously can improve both training efficiency and inference speed.
## Files
- `pytorch_model.bin`: Model weights
- `config.json`: Model configuration
- Additional training artifacts
## License
Apache 2.0 (following the GPT-2 base model)