Multi-Token Prediction GPT-2 Model

This is a GPT-2 model enhanced with a Multi-Token Prediction (MTP) architecture, trained on the MetaMathQA dataset for mathematical reasoning tasks.

Model Description

This model implements the Multi-Token Prediction approach from the paper "Better & Faster Large Language Models via Multi-token Prediction" by Meta AI. Key features:

  • Shared Trunk Architecture: Uses 9 shared transformer layers feeding 4 prediction heads (see the sketch after this list)
  • Multi-Token Prediction: Predicts 4 tokens simultaneously (t+1, t+2, t+3, t+4)
  • Enhanced Speculative Decoding: Achieves up to 3x inference speedup
  • Mathematical Reasoning: Fine-tuned specifically for mathematical problem solving
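
The trunk-and-heads design can be pictured as a shared GPT-2 transformer whose final hidden states feed four parallel output heads, one per future offset. The sketch below is a minimal, hypothetical reconstruction, not the repository's actual training code: the class name MultiTokenGPT2Sketch and the use of simple linear heads are illustrative assumptions (the paper attaches a transformer layer to each head).

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

class MultiTokenGPT2Sketch(nn.Module):
    """Minimal sketch: shared GPT-2 trunk feeding 4 parallel prediction heads."""

    def __init__(self, n_future: int = 4):
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        self.trunk = GPT2LMHeadModel(config).transformer  # shared transformer layers
        # One head per future offset (t+1 ... t+4); linear heads keep the sketch
        # short, whereas the paper uses a transformer layer plus a shared unembedding.
        self.heads = nn.ModuleList(
            [nn.Linear(config.n_embd, config.vocab_size, bias=False)
             for _ in range(n_future)]
        )

    def forward(self, input_ids, attention_mask=None):
        hidden = self.trunk(input_ids, attention_mask=attention_mask).last_hidden_state
        # logits[k][:, i] is the prediction for the token at position i + k + 1.
        return [head(hidden) for head in self.heads]

The card lists 9 shared trunk layers; the sketch simply reuses the full 12-layer GPT-2 stack for brevity.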

Architecture Details

  • Base Model: GPT-2 (124M parameters)
  • Trunk Layers: 9 (shared processing)
  • Prediction Heads: 4 parallel heads
  • Training Data: MetaMathQA dataset (500 samples)
  • Training Epochs: 1 with gradient accumulation

Usage

import torch
from transformers import GPT2Tokenizer
# Note: You'll need the custom MultiTokenGPT2 class from the training code

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the model (custom loading required; the generation calls below assume
# `model` is an instance of the custom MultiTokenGPT2 class with the
# downloaded weights loaded into it):
# model = MultiTokenGPT2.from_pretrained("Goldenwert/multitoken-gpt2-metamathqa")

prompt = "What is the derivative of f(x) = x^3?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Standard generation
generated = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=False
)

# Fast speculative generation
generated_fast = model.generate(
    input_ids,
    max_new_tokens=50,
    use_speculative=True
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(tokenizer.decode(generated_fast[0], skip_special_tokens=True))
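
Under the hood, speculative generation lets the extra heads draft several future tokens from a single forward pass, then verifies those drafts with the next-token head and keeps the longest agreeing prefix. A simplified greedy version of that loop is sketched below; it assumes the model's forward pass returns one logits tensor per head (as in the architecture sketch above) and is not the actual generate implementation.

import torch

@torch.no_grad()
def speculative_generate_sketch(model, input_ids, max_new_tokens=50):
    # Assumes model(ids) returns a list of 4 logits tensors, one per future
    # offset, each of shape (batch, seq_len, vocab_size). Written for batch size 1.
    ids = input_ids
    while ids.shape[1] < input_ids.shape[1] + max_new_tokens:
        logits = model(ids)                        # one forward pass over the prefix
        last = ids.shape[1] - 1
        # Draft: head k predicts the token k+1 steps past the last position.
        draft = [head_logits[:, last].argmax(-1) for head_logits in logits]
        candidate = torch.cat([ids] + [d.unsqueeze(1) for d in draft], dim=1)
        # Verify: the next-token head must agree with each drafted token,
        # given all previously accepted drafts as context.
        verify = model(candidate)[0].argmax(-1)
        accepted = 1                               # draft[0] is the next-token head's own prediction
        for k in range(1, len(draft)):
            if not torch.equal(verify[:, last + k], draft[k]):
                break
            accepted += 1
        ids = candidate[:, : ids.shape[1] + accepted]
    return ids[:, : input_ids.shape[1] + max_new_tokens]

A production implementation would reuse the verification pass as the next drafting pass and cache key/values; the sketch keeps the two passes separate for clarity.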

Performance

  • Inference Speed: Up to 3x faster with speculative decoding
  • Memory Efficiency: Gradient checkpointing support (see the example after this list)
  • Mathematical Tasks: Improved reasoning on math problems
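
Gradient checkpointing trades compute for memory by recomputing trunk activations during the backward pass instead of storing them. With a Hugging Face GPT-2 module it can be enabled with a single call; a minimal example, assuming a standard transformers model object rather than this repository's custom class:

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.gradient_checkpointing_enable()   # recompute activations in backward to save memory
model.config.use_cache = False          # the KV cache is incompatible with checkpointing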

Training Details

  • Dataset: MetaMathQA (mathematical reasoning)
  • Optimizer: AdamW with warmup
  • Learning Rate: 5e-5
  • Batch Size: 4 (effective 32 via 8 gradient-accumulation steps; see the training sketch after this list)
  • Hardware: GPU with FP16 precision
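
Putting those settings together, one training step might look like the sketch below. Only the optimizer, learning rate, accumulation factor, and FP16 precision come from the card; the model, dataloader, and step counts are placeholders.

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

accum_steps = 8          # per-step batch of 4 * 8 accumulation steps = effective 32
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000  # illustrative counts
)
scaler = torch.cuda.amp.GradScaler()  # FP16 mixed precision

for step, batch in enumerate(train_loader):          # train_loader: assumed DataLoader
    with torch.cuda.amp.autocast(dtype=torch.float16):
        logits_per_head = model(batch["input_ids"])   # one logits tensor per head
        loss = 0.0
        for k, logits in enumerate(logits_per_head):
            shift = k + 1                             # head k predicts the token k+1 ahead
            loss = loss + F.cross_entropy(
                logits[:, :-shift].reshape(-1, logits.size(-1)),
                batch["input_ids"][:, shift:].reshape(-1),
            )
        loss = loss / len(logits_per_head)
    scaler.scale(loss / accum_steps).backward()       # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()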

Research Context

Based on the paper "Better & Faster Large Language Models via Multi-token Prediction", which demonstrates that predicting multiple tokens simultaneously can improve both training efficiency and inference speed.

Files

  • pytorch_model.bin: Model weights (see the loading sketch after this list)
  • config.json: Model configuration
  • Additional training artifacts
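
Because the architecture is custom, the weights are loaded into an instance of the model class rather than through AutoModel. A minimal sketch, assuming the MultiTokenGPT2 class from the usage example and a local copy of pytorch_model.bin:

import torch

model = MultiTokenGPT2()                 # assumed constructor for the custom class
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()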

License

Apache 2.0 (following the GPT-2 base model)
