# RelaxingSnorlax's picture
# Upload folder using huggingface_hub
# 8867025 verified
"""
Speculator implementations providing a single, unified model for the EAGLE v1,
EAGLE v2, and HASS variants of speculative decoding:
- Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
- Eagle v2: https://arxiv.org/abs/2406.16858
- HASS: https://arxiv.org/abs/2408.15766

Classes:
    EagleSpeculatorConfig: Configuration class for EAGLE/HASS model variants
    EagleSpeculator: Main model implementation for EAGLE/HASS speculators
"""
import os
from typing import Any, ClassVar, Literal, Optional, Union

import torch
from pydantic import Field, field_serializer, field_validator, model_validator
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
)
from typing_extensions import Self

from speculators import SpeculatorModel, SpeculatorModelConfig
# Names exported by `from <this module> import *`.
__all__ = ["EagleSpeculator", "EagleSpeculatorConfig"]
@SpeculatorModelConfig.register("eagle")
class EagleSpeculatorConfig(SpeculatorModelConfig):
    """
    A SpeculatorModelConfig implementation to be used with the EagleSpeculator
    for EAGLE and HASS variants for spec decoding:
        - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
        - Eagle v2: https://arxiv.org/abs/2406.16858
        - HASS: https://arxiv.org/abs/2408.15766

    Model Configurations:
        - EAGLE1: layernorms=False, fusion_bias=False
        - EAGLE2: layernorms=False, fusion_bias=False
        - HASS: layernorms=False, fusion_bias=True

    Example:
        ```python
        from speculators import SpeculatorsConfig, VerifierConfig
        from speculators.models import EagleSpeculatorConfig
        from speculators.proposals import GreedyTokenProposalConfig
        from transformers import AutoConfig

        config = EagleSpeculatorConfig(
            transformer_layer_config=AutoConfig.from_pretrained(
                "meta-llama/Llama-3.1-8B-Instruct"
            ),
            speculators_config=SpeculatorsConfig(
                algorithm="eagle",
                proposal_methods=[
                    GreedyTokenProposalConfig(),
                ],
                default_proposal_method="greedy",
                verifier=VerifierConfig(
                    name_or_path="meta-llama/Llama-3.1-8B-Instruct",
                    architectures=["LlamaForCausalLM"],
                ),
            ),
        )
        ```
    """

    speculators_model_type: Literal["eagle"] = "eagle"
    architectures: list[str] = Field(
        default_factory=lambda: ["EagleSpeculator"],
        description=(
            "List of model architectures that can be used with the model "
            "pretrained weights. Automatically includes the transformer layer "
            "architecture to ensure compatibility during model loading and "
            "validation."
        ),
    )

    transformer_layer_architecture: str = Field(
        default="LlamaDecoderLayer",
        description=(
            "The architecture class name of the transformer layer to use for "
            "the speculator's decoder layer. Must correspond to a valid "
            "transformer decoder layer class (e.g., 'LlamaDecoderLayer')."
        ),
    )
    transformer_layer_config: PretrainedConfig = Field(
        default_factory=LlamaConfig,
        description=(
            "Configuration object for the transformer layer architecture. "
            "Must be a PretrainedConfig instance that matches the requirements "
            "of the transformer_layer_architecture. Contains parameters such as "
            "hidden_size, num_attention_heads, intermediate_size, vocab_size, "
            "and other architecture-specific settings."
        ),
    )
    layernorms: bool = Field(
        default=False,
        description=(
            "Whether to include additional layer normalization layers in the "
            "model architecture. When True, applies RMSNorm to the input token "
            "embeddings before fusion (embedding_layernorm), keeps the decoder "
            "layer's input layernorm on the fusion layer output, and adds an "
            "RMSNorm before the language model head (pre_lm_head_layernorm). "
            "When False, these layers are not included and the decoder layer's "
            "input layernorm is replaced with an identity. "
            "Standard EAGLE1, EAGLE2, and HASS implementations use False."
        ),
    )
    fusion_bias: bool = Field(
        default=False,
        description=(
            "Whether to add a learnable bias term to the fusion (fully connected) "
            "layer that combines input embeddings with verifier hidden states. "
            "The fusion layer concatenates input embeddings and hidden states, "
            "then projects to hidden_size dimensions. Standard EAGLE1 and EAGLE2 "
            "use False, while HASS uses True."
        ),
    )

    @model_validator(mode="after")
    def check_add_architectures(self) -> Self:
        """
        Automatically adds the transformer layer architecture to the
        architectures list if it's not already present.

        :return: The validated configuration instance with updated architectures
        """
        if self.transformer_layer_architecture not in self.architectures:
            self.architectures.append(self.transformer_layer_architecture)

        return self

    @field_serializer("transformer_layer_config")
    def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
        """
        Serialize the transformer_layer_config to a dictionary for JSON storage.

        Converts the PretrainedConfig object to its dictionary representation
        using to_diff_dict() to only include non-default values.

        :param value: The PretrainedConfig instance to serialize
        :return: Dictionary representation of the transformer layer configuration
        """
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
        """
        Validate and convert transformer_layer_config to a PretrainedConfig instance.

        Accepts either an existing PretrainedConfig instance, which is returned
        unchanged, or a dictionary (for example the output of
        serialize_transformer_layer_config). When the dictionary's ``model_type``
        is registered in transformers' CONFIG_MAPPING, the matching config class
        (e.g. LlamaConfig for ``"llama"``) is used. This keeps the concrete config
        type across a serialize/load round trip, and class-specific defaults fill
        in any keys the dictionary omits. Other dictionaries fall back to a
        generic PretrainedConfig.

        :param value: The value to validate (dict or PretrainedConfig)
        :return: A validated PretrainedConfig instance
        :raises ValueError: If the value cannot be converted to a PretrainedConfig
        """
        if isinstance(value, PretrainedConfig):
            return value

        if isinstance(value, dict):
            model_type = value.get("model_type")
            config_class: type[PretrainedConfig] = PretrainedConfig
            if isinstance(model_type, str) and model_type in CONFIG_MAPPING:
                config_class = CONFIG_MAPPING[model_type]
            # from_dict may write keys into the dict it receives; pass a copy so
            # the caller's dictionary is left untouched.
            return config_class.from_dict(dict(value))

        raise ValueError(
            "transformer_layer_config must be a PretrainedConfig instance or a "
            "dictionary that can be converted to a PretrainedConfig."
        )
@SpeculatorModel.register("eagle")
class EagleSpeculator(SpeculatorModel):
    """
    A SpeculatorModel implementation for EAGLE and HASS variants for spec decoding:
        - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
        - Eagle v2: https://arxiv.org/abs/2406.16858
        - HASS: https://arxiv.org/abs/2408.15766

    Architecture Overview:
        The EAGLE speculator consists of:
        1. Input embedding layer (shared with verifier)
        2. Optional embedding layer normalization
        3. Fusion layer: Concatenates and projects input embeddings + verifier hidden
           states to a latent space of hidden_size
        4. Single transformer decoder layer for candidate token generation
        5. Optional pre-LM head layer normalization
        6. Language model head (shared with verifier)

    Speculative Decoding Process:
        1. Verifier model processes input and generates hidden states
        2. EAGLE speculator uses these hidden states + input embeddings to predict
           next tokens
        3. Multiple candidate tokens generated in parallel using token proposal methods
        4. Verifier validates candidates and accepts/rejects based on probability
           thresholds
        5. Process continues iteratively for multi-token speculation

    Example:
        ```python
        from speculators import SpeculatorsConfig, VerifierConfig
        from speculators.models import EagleSpeculator, EagleSpeculatorConfig
        from speculators.proposals import GreedyTokenProposalConfig
        from transformers import AutoConfig, AutoModelForCausalLM

        config = EagleSpeculatorConfig(
            transformer_layer_config=AutoConfig.from_pretrained(
                "meta-llama/Llama-3.1-8B-Instruct"
            ),
            speculators_config=SpeculatorsConfig(
                algorithm="eagle",
                proposal_methods=[
                    GreedyTokenProposalConfig(),
                ],
                default_proposal_method="greedy",
                verifier=VerifierConfig(
                    name_or_path="meta-llama/Llama-3.1-8B-Instruct",
                    architectures=["LlamaForCausalLM"],
                ),
            ),
        )
        verifier = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.1-8B-Instruct"
        )
        speculator = EagleSpeculator(
            config, verifier=verifier, verifier_attachment_mode="full"
        )
        ```
    """

    # PreTrainedModel settings
    config_class: ClassVar[type[EagleSpeculatorConfig]] = EagleSpeculatorConfig  # type: ignore[misc]
    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
        "verifier*",
        "embed_tokens*",
        "lm_head*",
    ]
    _keys_to_ignore_on_save: ClassVar[list[str]] = [  # type: ignore[assignment,misc]
        "embed_tokens.weight",
        "lm_head.weight",
        "lm_head.bias",
    ]

    def __init__(
        self,
        config: EagleSpeculatorConfig,
        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
        verifier_attachment_mode: Optional[
            Literal["detached", "full", "train_only"]
        ] = None,
    ):
        """
        Initializes an EAGLE speculator architecture with configurable components based
        on the provided configuration. The model starts with verifier-dependent layers
        (embed_tokens, rotary_emb, lm_head) set to None until a verifier is attached.

        :param config: Configuration object specifying model architecture, layer
            settings, and speculative decoding parameters. Must be an instance of
            EagleSpeculatorConfig containing transformer layer configuration and
            EAGLE-specific settings.
        :param verifier: Optional verifier model to attach for speculative decoding.
            Can be a path to a model directory, Hugging Face model identifier, or
            PreTrainedModel instance. If None, must be attached later via
            attach_verifier() before using the model.
        :param verifier_attachment_mode: Mode for verifier attachment. "detached"
            prevents attachment even if verifier is provided. "full" enables
            complete integration for both training and generation. "train_only"
            attaches only components needed for training, optimizing memory usage.
        """
        if not isinstance(config, EagleSpeculatorConfig):
            raise ValueError(
                "config must be an instance of EagleSpeculatorConfig, "
                f"got {type(config)} instead."
            )

        # Initialize model parameters from config
        self.vocab_size = config.transformer_layer_config.vocab_size
        self.hidden_size = config.transformer_layer_config.hidden_size
        self.padding_idx = config.transformer_layer_config.pad_token_id

        # Set layers pulled from the verifier to None until attach is called
        self.embed_tokens: Optional[nn.Embedding] = None
        self.rotary_emb: Optional[nn.Module] = None
        self.lm_head: Optional[nn.Linear] = None

        # Delayed initialization to ensure everything needed for attach_verifier is set
        super().__init__(
            config=config,
            verifier=verifier,
            verifier_attachment_mode=verifier_attachment_mode,
        )

        # Initialize layers based on the configuration
        self.embedding_layernorm: Optional[nn.Module] = self._create_layernorm()
        self.fusion_fc: nn.Linear = nn.Linear(
            2 * self.hidden_size,
            self.hidden_size,
            bias=config.fusion_bias,
        )
        self.transformer: nn.Module = self._create_transformer_layer()
        self.pre_lm_head_layernorm: Optional[nn.Module] = self._create_layernorm()

        self.post_init()  # type: ignore[attr-defined]

    def attach_verifier(
        self,
        verifier: Union[str, os.PathLike, PreTrainedModel],
        mode: Optional[Literal["full", "train_only"]] = None,
    ) -> PreTrainedModel:
        """
        Attach a verifier model to the EagleSpeculator for speculative decoding.
        Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers
        for the speculator's forward pass and generation methods.
        Additionally, for `generate`, it uses the verifier's hidden states
        to generate speculative token predictions.

        If mode is "full", the verifier is fully integrated for use with
        both `generate` and `forward` methods.

        If mode is "train_only", only the verifier's layers required for a forward pass
        are attached, allowing for better resource utilization during training.
        `generate` will not be available until a full verifier is attached.

        Example:
            ```python
            # Load and attach a verifier
            verifier = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Llama-3.1-8B-Instruct"
            )

            # For generation
            speculator.attach_verifier(verifier)
            outputs = speculator.generate(input_ids)
            speculator.detach_verifier()

            # For training
            speculator.attach_verifier(verifier, mode="train_only")
            outputs = speculator(input_ids, hidden_states)
            speculator.detach_verifier()
            ```

        :param verifier: The verifier model to attach. This can be a path to a local
            model directory, a Hugging Face model identifier, or an instance of
            PreTrainedModel. If a path or identifier is provided, the model will be
            loaded automatically. If an instance is provided, it will be used directly.
        :param mode: The mode for attaching the verifier. Can be "full" or "train_only".
            If None, defaults to "full". In "train_only" mode, only the layers
            required for a forward pass are attached, and the speculator cannot
            perform generation until a full verifier is attached.
        :return: The PreTrainedModel instance for the verifier that was attached.
        """
        verifier = super().attach_verifier(
            verifier=verifier,
            mode=mode,
        )

        # Extract layers from the verifier model
        if hasattr(verifier, "model"):
            self.embed_tokens = verifier.model.embed_tokens  # type: ignore[assignment]
            self.rotary_emb = verifier.model.rotary_emb  # type: ignore[assignment]
        else:
            # Bare model structure
            self.embed_tokens = verifier.embed_tokens  # type: ignore[assignment]
            self.rotary_emb = verifier.rotary_emb  # type: ignore[assignment]

        # lm_head is always at the top level of the verifier
        self.lm_head = verifier.lm_head

        return verifier

    def detach_verifier(self):
        """
        Drop this speculator's references to the attached verifier and to the
        layers borrowed from it (embed_tokens, rotary_emb, lm_head). Until a new
        verifier is attached, forward() raises ValueError because those layers
        are None.
        """
        super().detach_verifier()

        # `del` first so a registered submodule is removed from the module
        # registry, then store a plain None attribute in its place.
        del self.embed_tokens
        self.embed_tokens = None
        del self.rotary_emb
        self.rotary_emb = None
        del self.lm_head
        self.lm_head = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,  # noqa: ARG002
        return_dict: Optional[bool] = None,
    ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
        """
        Execute the forward pass for speculative token generation.

        Processes input tokens and verifier hidden states through the EAGLE architecture
        to generate candidate tokens for speculative decoding. The method combines input
        embeddings with verifier hidden states via a fusion layer, processes them
        through a transformer decoder layer, and produces logits for next token
        prediction.

        :param input_ids: Token IDs for the current input sequence. Shape: (batch_size,
            sequence_length). These represent the tokens that will be converted to
            embeddings and combined with verifier hidden states.
        :param hidden_states: Hidden state representations from the verifier model
            corresponding to the input sequence. Shape: (batch_size, sequence_length,
            hidden_size). These capture the verifier's understanding of the context.
        :param attention_mask: Optional attention mask to avoid attending to padding
            tokens. Shape: (batch_size, sequence_length) for 2D or (batch_size, 1,
            sequence_length, sequence_length) for 4D causal mask.
        :param position_ids: Optional position indices for tokens in the sequence.
            Shape: (batch_size, sequence_length). If None, generated as
            past_length .. past_length + sequence_length - 1, where past_length is
            the number of positions already held in past_key_values.
        :param past_key_values: Optional cached key-value states from previous forward
            passes for efficient generation. Tuple of layer key-value pairs.
        :param use_cache: Whether to return key-value states for caching in subsequent
            forward passes. Useful for autoregressive generation efficiency.
        :param output_attentions: Whether to return attention weights from the
            transformer layer. Used for analysis and visualization.
        :param output_hidden_states: Whether to return hidden states from the
            transformer layer. Currently not implemented in this model.
        :param return_dict: Whether to return structured CausalLMOutputWithPast instead
            of raw logits. If None, uses config.use_return_dict default.
        :return: Either raw logits tensor (batch_size, sequence_length, vocab_size) if
            return_dict=False, or CausalLMOutputWithPast containing logits, the
            layer's cache (when use_cache and the layer returned one) and, when
            output_attentions is set and the layer returned them, a one-element
            tuple with the decoder layer's attention weights.
        :raises ValueError: If verifier components (embed_tokens, rotary_emb, lm_head)
            are not attached. Call attach_verifier() before using forward().
        """
        if self.embed_tokens is None or self.rotary_emb is None or self.lm_head is None:
            raise ValueError(
                "Verifier model layers not initialized. "
                "Call `attach_verifier` to set up the model before using forward."
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds = self.embed_tokens(input_ids)
        if self.embedding_layernorm is not None:
            inputs_embeds = self.embedding_layernorm(inputs_embeds)

        # Fuse token embeddings with verifier hidden states along the feature
        # axis, then project 2 * hidden_size -> hidden_size.
        hidden_states = self.fusion_fc(
            torch.cat([inputs_embeds, hidden_states], dim=-1)
        )
        hidden_states, attention_mask, position_ids = self._prepare_decoder_inputs(
            hidden_states, attention_mask, position_ids, past_key_values
        )

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        layer_outputs = self.transformer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values[0] if past_key_values else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=(cos, sin),
        )
        hidden_states, attentions, present_key_value = self._unpack_layer_outputs(
            layer_outputs,
            output_attentions=bool(output_attentions),
            use_cache=bool(use_cache),
        )

        if self.pre_lm_head_layernorm is not None:
            hidden_states = self.pre_lm_head_layernorm(hidden_states)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            return logits

        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=present_key_value,
            hidden_states=None,
            attentions=(attentions,) if attentions is not None else None,
        )

    @staticmethod
    def _past_key_values_length(
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]],
    ) -> int:
        """
        Return the number of already-cached positions (0 when there is no cache).

        NOTE(review): assumes the legacy layout where past_key_values[0][0] is the
        first layer's key tensor with the sequence on dim 2 — confirm for Cache
        objects.
        """
        return past_key_values[0][0].shape[2] if past_key_values else 0

    def _prepare_decoder_inputs(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.LongTensor],
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]],
    ) -> tuple[torch.FloatTensor, Optional[torch.Tensor], Optional[torch.LongTensor]]:
        """
        Fill in default position ids and expand a 2D padding mask into the 4D
        causal mask consumed by the decoder layer.

        :param hidden_states: Fused hidden states; only the leading
            (batch, sequence) dimensions, device and dtype are used here.
        :param attention_mask: None, a 2D (batch, sequence) mask to expand, or an
            already-prepared mask that is passed through unchanged.
        :param position_ids: Explicit positions, or None to generate positions
            that continue after any cached prefix.
        :param past_key_values: Cache whose length offsets generated positions
            and the causal mask.
        :return: (hidden_states unchanged, prepared attention mask, position ids)
        """
        batch_size, seq_length = hidden_states.shape[:2]

        if position_ids is None:
            # New tokens continue after the cached prefix rather than restarting at 0.
            past_length = self._past_key_values_length(past_key_values)
            position_ids = (
                torch.arange(  # type: ignore[assignment]
                    past_length,
                    past_length + seq_length,
                    dtype=torch.long,
                    device=hidden_states.device,
                )
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                self._past_key_values_length(past_key_values),
                # `self.config` is the speculator config; attention settings such
                # as sliding_window live on the transformer layer's config.
                sliding_window=getattr(
                    self.config.transformer_layer_config, "sliding_window", None
                ),
            )

        return hidden_states, attention_mask, position_ids

    @staticmethod
    def _unpack_layer_outputs(
        layer_outputs: Union[torch.Tensor, tuple[Any, ...]],
        output_attentions: bool,
        use_cache: bool,
    ) -> tuple[torch.Tensor, Optional[Any], Optional[Any]]:
        """
        Split the decoder layer's return value into
        (hidden_states, attention_weights, present_key_value).

        Accepts either a bare hidden-state tensor or a tuple laid out as
        (hidden_states, [attn_weights if requested], [cache if requested]).
        NOTE(review): the decoder layer's return format is owned by transformers
        and differs between releases — confirm against the pinned version.
        Entries the layer did not return come back as None instead of being read
        from the wrong tuple slot or raising IndexError.
        """
        if isinstance(layer_outputs, torch.Tensor):
            return layer_outputs, None, None

        hidden_states, *extras = layer_outputs
        # Attention weights, when present, precede the cache in the tuple.
        attentions = extras.pop(0) if output_attentions and extras else None
        present_key_value = extras[0] if use_cache and extras else None
        return hidden_states, attentions, present_key_value

    def _create_layernorm(self) -> Optional[nn.Module]:
        """
        Build an RMSNorm over hidden_size using the layer config's rms_norm_eps,
        or return None when config.layernorms is False.
        """
        if not self.config.layernorms:
            return None

        return self._layernorm_class()(
            self.hidden_size, eps=self.config.transformer_layer_config.rms_norm_eps
        )

    def _create_transformer_layer(self) -> nn.Module:
        """
        Instantiate the single decoder layer (layer_idx=0) from
        transformer_layer_config. When layernorms is disabled, the layer's
        input_layernorm is swapped for nn.Identity so the fused hidden states
        enter attention unnormalized.
        """
        layer_class = self._transformer_layer_class()
        layer = layer_class(
            self.config.transformer_layer_config,
            layer_idx=0,
        )

        if not self.config.layernorms:
            # Replace input_layernorm with Identity if layernorms are not used
            layer.input_layernorm = nn.Identity()

        return layer

    def _layernorm_class(self) -> type[nn.Module]:
        """Normalization class used for the optional layernorms."""
        return LlamaRMSNorm

    def _transformer_layer_class(self) -> type[nn.Module]:
        """Decoder layer class instantiated by _create_transformer_layer."""
        return LlamaDecoderLayer