RelaxingSnorlax committed on
Commit aa996f9 · verified · 1 Parent(s): ac37b0d

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +51 -0
  2. config.json +54 -0
  3. eagle3.py +543 -0
  4. generation_config.json +4 -0
  5. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,51 @@
+# Eagle-3 Speculator for Llama-3.1-8B-Instruct
+
+This is an Eagle-3 speculator checkpoint converted to the [speculators](https://github.com/neuralmagic/speculators) format.
+
+## Model Details
+
+- **Base Model**: meta-llama/Meta-Llama-3.1-8B-Instruct
+- **Speculator Type**: Eagle-3
+- **Draft Vocabulary Size**: 32,000
+- **Target Vocabulary Size**: 128,256
+- **Architecture**: Single-layer transformer with vocabulary mapping
+
+## Key Features
+
+- **Vocabulary Mapping**: Maps between the draft (32K) and target (128K) vocabularies
+- **Custom Attention**: Modified attention layer accepting 2×hidden_size input
+- **Fusion Layer**: Fuses hidden states from 3 verifier layers (3×4096 → 4096)
+- **Layer Normalization**: Applied before the residual connection (HF checkpoint style)
+
+## Usage
+
+```python
+from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig
+from transformers import AutoModelForCausalLM
+
+# Load verifier model
+verifier = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+
+# Load Eagle-3 speculator
+speculator = Eagle3Speculator.from_pretrained(
+    "nm-testing/eagle3-llama3.1-8b-instruct-speculators",
+    verifier=verifier,
+)
+```
+
+## Configuration
+
+This model uses the Eagle-3 architecture with:
+- Hidden size: 4096
+- Attention heads: 32
+- Key-value heads: 8
+- Intermediate size: 14336
+- RMS norm epsilon: 1e-05
+
+## Citation
+
+Based on the Eagle-3 paper: https://arxiv.org/abs/2503.01840
+
+## License
+
+Please refer to the base Llama-3.1 model license.
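Annotation: the vocabulary mapping called out in Key Features is exposed through helper methods defined in eagle3.py below. A minimal, hedged sketch of how they would be called once `speculator` has been loaded as in the Usage snippet (the token IDs are arbitrary placeholders):

```python
import torch

# Assumes `speculator` was loaded as in the README's Usage snippet.
draft_ids = torch.tensor([0, 5, 42])                                  # IDs in the 32K draft vocab
target_ids = speculator.map_draft_to_target_tokens(draft_ids)         # offset lookup via the d2t buffer
has_draft = speculator.check_target_token_availability(target_ids)    # boolean mask from the t2d buffer
print(target_ids, has_draft)
```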
config.json ADDED
@@ -0,0 +1,54 @@
+{
+  "architectures": [
+    "Eagle3Speculator"
+  ],
+  "auto_map": {
+    "": "eagle3.Eagle3SpeculatorConfig"
+  },
+  "draft_vocab_size": 32000,
+  "has_no_defaults_at_init": false,
+  "norm_before_residual": true,
+  "speculators_config": {
+    "algorithm": "eagle3",
+    "default_proposal_method": "greedy",
+    "proposal_methods": [
+      {
+        "accept_tolerance": 0.0,
+        "proposal_type": "greedy",
+        "speculative_tokens": 5,
+        "verifier_accept_k": 1
+      }
+    ],
+    "verifier": {
+      "architectures": [
+        "LlamaForCausalLM"
+      ],
+      "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    }
+  },
+  "speculators_model_type": "eagle3",
+  "speculators_version": "0.1.0.dev13",
+  "torch_dtype": "float32",
+  "transformer_layer_config": {
+    "attention_bias": false,
+    "attention_dropout": 0.0,
+    "head_dim": 128,
+    "hidden_act": "silu",
+    "hidden_size": 4096,
+    "initializer_range": 0.02,
+    "intermediate_size": 14336,
+    "max_position_embeddings": 131072,
+    "mlp_bias": false,
+    "model_type": "llama",
+    "num_attention_heads": 32,
+    "num_hidden_layers": 1,
+    "num_key_value_heads": 8,
+    "pretraining_tp": 1,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": null,
+    "rope_theta": 500000.0,
+    "use_cache": true,
+    "vocab_size": 128256
+  },
+  "transformers_version": "4.52.4"
+}
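Annotation: the `speculators_config` block above carries the proposal settings (greedy proposals, 5 speculative tokens per step) and the verifier reference. A minimal sketch that reads those fields from a locally downloaded copy of this config.json with the standard library (the local path is an assumption):

```python
import json

# Assumes this repo's config.json has been downloaded to the working directory.
with open("config.json") as f:
    cfg = json.load(f)

proposal = cfg["speculators_config"]["proposal_methods"][0]
print(cfg["speculators_config"]["algorithm"])                                   # eagle3
print(proposal["speculative_tokens"])                                           # 5
print(cfg["draft_vocab_size"], cfg["transformer_layer_config"]["vocab_size"])   # 32000 128256
print(cfg["speculators_config"]["verifier"]["name_or_path"])
```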
eagle3.py ADDED
@@ -0,0 +1,543 @@
+"""
+Speculators implementation of EAGLE-3:
+- https://arxiv.org/abs/2503.01840
+
+Classes:
+    Eagle3SpeculatorConfig: Configuration class for EAGLE-3 speculator model
+    Eagle3Speculator: Main model implementation for EAGLE-3 speculators
+    Eagle3Attention: Custom attention layer for EAGLE-3, processes
+        concatenated embeddings and hidden states
+    Eagle3DecoderLayer: Custom decoder layer for EAGLE-3, processes
+        concatenated embeddings and hidden states with Eagle3Attention
+        and support for moving hidden layernorm before residual
+"""
+
+import os
+from typing import Any, ClassVar, Literal, Optional, Union
+
+import torch
+from pydantic import Field, field_serializer, field_validator
+from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaMLP,
+    LlamaRMSNorm,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+
+from speculators import SpeculatorModel, SpeculatorModelConfig
+
+__all__ = [
+    "Eagle3Attention",
+    "Eagle3DecoderLayer",
+    "Eagle3Speculator",
+    "Eagle3SpeculatorConfig",
+]
+
+
+@SpeculatorModelConfig.register("eagle3")
+class Eagle3SpeculatorConfig(SpeculatorModelConfig):
+    """
+    Configuration for EAGLE-3 speculator with vocabulary mapping.
+
+    EAGLE-3 features vocabulary mapping between draft (32K) and target (128K)
+    vocabularies, enabling cross-tokenizer speculation.
+
+    :param transformer_layer_config: Configuration for the transformer decoder layer
+    :param draft_vocab_size: Size of draft model vocabulary for speculation
+    :param norm_before_residual: Apply hidden_norm before storing residual
+    """
+
+    speculators_model_type: Literal["eagle3"] = "eagle3"
+    architectures: list[str] = Field(
+        default_factory=lambda: ["Eagle3Speculator"],
+        description="Model architectures that can load these weights",
+    )
+
+    transformer_layer_config: PretrainedConfig = Field(
+        default_factory=LlamaConfig,
+        description="Configuration for the transformer decoder layer",
+    )
+
+    draft_vocab_size: int = Field(
+        default=32000,
+        description="Size of draft model vocabulary for speculation",
+    )
+
+    norm_before_residual: bool = Field(
+        default=False,
+        description="Apply hidden_norm before storing residual",
+    )
+
+    @property
+    def target_vocab_size(self) -> int:
+        """Get target vocabulary size from transformer config."""
+        return self.transformer_layer_config.vocab_size
+
+    @field_serializer("transformer_layer_config")
+    def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
+        """Serialize transformer config to dict."""
+        return value.to_diff_dict()
+
+    @field_validator("transformer_layer_config", mode="before")
+    @classmethod
+    def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
+        """Validate and convert transformer config."""
+        if isinstance(value, dict):
+            config_class: type[PretrainedConfig] = LlamaConfig
+            if "model_type" in value:
+                from transformers import AutoConfig
+
+                config_class = AutoConfig.for_model(
+                    model_type=value["model_type"]
+                ).__class__
+            return config_class(**value)
+        return value
+
+
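Annotation: `validate_transformer_config` rebuilds a full `PretrainedConfig` from a plain dict by dispatching on `model_type` through `AutoConfig.for_model`. A standalone sketch of that mechanism with toy sizes (not this checkpoint's values):

```python
from transformers import AutoConfig

# Toy layer-config dict; the real checkpoint uses hidden_size=4096 etc.
layer_dict = {"model_type": "llama", "hidden_size": 256, "num_attention_heads": 8}

# Same dispatch as validate_transformer_config: model_type -> config class.
config_class = AutoConfig.for_model(model_type=layer_dict["model_type"]).__class__
layer_config = config_class(**layer_dict)
print(type(layer_config).__name__)  # LlamaConfig
```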
+class Eagle3Attention(nn.Module):
+    """
+    Eagle-3 attention module that processes concatenated embeddings and hidden states.
+
+    Modified from standard Llama attention to accept 2x hidden_size input
+    for Q/K/V projections while maintaining standard output size.
+    """
+
+    def __init__(self, config: PretrainedConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.hidden_size = config.hidden_size
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+        input_size = 2 * self.hidden_size
+        self.q_proj = nn.Linear(
+            input_size, self.num_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            input_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            input_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            self.hidden_size, self.hidden_size, bias=config.attention_bias
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,  # noqa: ARG002
+    ) -> tuple:
+        """
+        Forward pass for Eagle-3 attention.
+        Taken from Llama Attention but modified to accept 2x hidden_size input.
+
+        :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+        :param attention_mask: Optional attention mask
+        :param position_ids: Optional position IDs for rotary embeddings
+        :param past_key_value: Optional cached key-value pairs
+        :param output_attentions: Whether to return attention weights
+        :param use_cache: Whether to cache key-value pairs
+        :param position_embeddings: Optional precomputed rotary embeddings
+        :return: Tuple of (hidden_states, [attention_weights], [past_key_value])
+        """
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids
+            )
+
+        past_key_value_out = None
+        if past_key_value is not None:
+            past_key = past_key_value[0]
+            past_value = past_key_value[1]
+            key_states = torch.cat([past_key, key_states], dim=2)
+            value_states = torch.cat([past_value, value_states], dim=2)
+
+        if use_cache:
+            past_key_value_out = (key_states, value_states)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (
+            self.head_dim**0.5
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value_out
+
+
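Annotation: the doubled input width is the main structural change relative to stock Llama attention: the Q/K/V projections read from 2*hidden_size while the output projection stays at hidden_size. A quick shape check with toy dimensions, assuming eagle3.py is importable as a local module:

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from eagle3 import Eagle3Attention  # this file, used as a local module

cfg = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=4)
attn = Eagle3Attention(cfg, layer_idx=0)

print(attn.q_proj.in_features, attn.q_proj.out_features)  # 512 256
x = torch.randn(2, 5, 2 * cfg.hidden_size)  # [batch, seq, 2*hidden_size]
out, _, _ = attn(x)                         # no mask / rotary in this toy run
print(out.shape)                            # torch.Size([2, 5, 256])
```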
+class Eagle3DecoderLayer(nn.Module):
+    """
+    Eagle-3 decoder layer that processes concatenated embeddings and hidden states.
+
+    Accepts 2x hidden_size input from concatenated embeddings and fused hidden states.
+    Uses Eagle3Attention for the self-attention computation.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        norm_before_residual: bool = False,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.norm_before_residual = norm_before_residual
+
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.self_attn = Eagle3Attention(config, layer_idx)
+
+        self.mlp = LlamaMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,  # noqa: ARG002
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,  # noqa: ARG002
+    ) -> tuple:
+        """
+        Process concatenated embeddings and hidden states through modified decoder
+        layer.
+
+        :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+        :return: Tuple of layer outputs
+        """
+        embeds = hidden_states[:, :, : self.hidden_size]
+        hidden = hidden_states[:, :, self.hidden_size : 2 * self.hidden_size]
+
+        if self.norm_before_residual:
+            hidden = self.hidden_norm(hidden)
+            residual = hidden
+        else:
+            residual = hidden
+            hidden = self.hidden_norm(hidden)
+
+        embeds = self.input_layernorm(embeds)
+
+        attn_input = torch.cat([embeds, hidden], dim=-1)
+
+        attn_output, attn_weights, past_key_value_out = self.self_attn(
+            hidden_states=attn_input,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            position_embeddings=position_embeddings,
+        )
+
+        hidden_states = residual + attn_output
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)  # type: ignore[assignment]
+
+        if use_cache:
+            outputs += (past_key_value_out,)  # type: ignore[assignment]
+
+        return outputs
+
+
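Annotation: `norm_before_residual` only changes whether `hidden_norm` is applied before or after the residual copy of the fused hidden states is taken; this checkpoint sets it to true (see config.json). A toy comparison, again assuming eagle3.py is importable locally:

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from eagle3 import Eagle3DecoderLayer  # this file, used as a local module

cfg = LlamaConfig(
    hidden_size=256, num_attention_heads=8, num_key_value_heads=4, intermediate_size=512
)
x = torch.randn(1, 4, 2 * cfg.hidden_size)  # concatenated [embeddings | fused hidden]

for flag in (False, True):
    torch.manual_seed(0)  # same random weights for both layers, so only the flag differs
    layer = Eagle3DecoderLayer(cfg, layer_idx=0, norm_before_residual=flag)
    (out,) = layer(x)
    print(flag, out.shape)  # torch.Size([1, 4, 256]) in both cases, values differ
```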
+@SpeculatorModel.register("eagle3")
+class Eagle3Speculator(SpeculatorModel):
+    """
+    EAGLE-3 speculator with vocabulary mapping and multi-layer fusion.
+
+    EAGLE-3 processes concatenated hidden states from multiple verifier layers
+    through a fusion layer, then combines with embeddings for a custom decoder
+    layer that accepts 2x hidden_size input.
+    """
+
+    config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig  # type: ignore[misc]
+    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
+        "verifier*",
+    ]
+    _keys_to_ignore_on_save: ClassVar[list[str]] = []  # type: ignore[misc,assignment]
+
+    def __init__(
+        self,
+        config: Eagle3SpeculatorConfig,
+        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
+        verifier_attachment_mode: Optional[
+            Literal["detached", "full", "train_only"]
+        ] = None,
+    ):
+        """
+        Initialize Eagle3 speculator.
+
+        :param config: Eagle3SpeculatorConfig instance
+        :param verifier: Optional verifier model
+        :param verifier_attachment_mode: How to attach the verifier
+        """
+        if not isinstance(config, Eagle3SpeculatorConfig):
+            raise ValueError(
+                f"config must be Eagle3SpeculatorConfig, got {type(config)}"
+            )
+
+        self.config: Eagle3SpeculatorConfig = config
+
+        self.hidden_size = config.transformer_layer_config.hidden_size
+        self.draft_vocab_size = config.draft_vocab_size
+        self.target_vocab_size = config.target_vocab_size
+
+        super().__init__(
+            config=config,
+            verifier=verifier,
+            verifier_attachment_mode=verifier_attachment_mode,
+        )
+
+        self.embed_tokens = nn.Embedding(
+            self.target_vocab_size,
+            self.hidden_size,
+            padding_idx=config.transformer_layer_config.pad_token_id
+            if hasattr(config.transformer_layer_config, "pad_token_id")
+            else None,
+        )
+
+        self.fc = nn.Linear(
+            3 * self.hidden_size,
+            self.hidden_size,
+            bias=False,
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                Eagle3DecoderLayer(
+                    config.transformer_layer_config,
+                    layer_idx=0,
+                    norm_before_residual=config.norm_before_residual,
+                )
+            ]
+        )
+
+        self.norm = LlamaRMSNorm(
+            self.hidden_size,
+            eps=config.transformer_layer_config.rms_norm_eps,
+        )
+
+        self.lm_head = nn.Linear(
+            self.hidden_size,
+            self.draft_vocab_size,
+            bias=False,
+        )
+
+        self.register_buffer(
+            "d2t",
+            torch.zeros(self.draft_vocab_size, dtype=torch.long),
+        )
+        self.register_buffer(
+            "t2d",
+            torch.zeros(self.target_vocab_size, dtype=torch.bool),
+        )
+
+        # Type hints for buffers
+        self.d2t: torch.Tensor
+        self.t2d: torch.Tensor
+
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,  # noqa: ARG002
+        return_dict: Optional[bool] = None,
+    ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
+        """
+        Forward pass for EAGLE-3 speculation.
+
+        :param input_ids: Input token IDs from draft vocabulary
+        :param hidden_states: Concatenated hidden states from 3 verifier layers
+            [B, L, 3*H]
+        :param attention_mask: Optional attention mask
+        :param position_ids: Optional position IDs
+        :param past_key_values: Optional cached key-values
+        :param use_cache: Whether to cache key-values
+        :param output_attentions: Return attention weights
+        :param output_hidden_states: Return hidden states
+        :param return_dict: Return dict output
+        :return: Model outputs with draft vocabulary logits
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        fused_hidden = self.fc(hidden_states)
+
+        layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1)
+
+        batch_size, seq_length = layer_input.shape[:2]
+        if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
+            past_key_values_length = (
+                past_key_values[0][0].shape[2] if past_key_values else 0
+            )
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+            )
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = (
+                torch.arange(  # type: ignore[assignment]
+                    seq_length, dtype=torch.long, device=device
+                )
+                .unsqueeze(0)
+                .expand(batch_size, -1)
+            )
+
+        layer_outputs = self.layers[0](
+            layer_input,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values[0] if past_key_values else None,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        hidden_states = self.norm(hidden_states)
+
+        logits = self.compute_logits(hidden_states, map_to_target_vocab=True)
+
+        if not return_dict:
+            return logits
+
+        return CausalLMOutputWithPast(
+            logits=logits,
+            past_key_values=[layer_outputs[1]] if use_cache else None,  # type: ignore[arg-type]
+            hidden_states=None,
+            attentions=None,
+        )
+
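Annotation: the data flow in `forward` is fuse the three captured verifier hidden states (3*H -> H), concatenate with the draft token embeddings (-> 2*H), run the single decoder layer, then project to the draft vocabulary. A shape-only sketch of that pipeline with toy sizes (the real model uses H = 4096 and a 32,000-entry draft head):

```python
import torch
from torch import nn

B, T, H, draft_vocab = 2, 5, 256, 1000         # toy sizes for illustration only
fc = nn.Linear(3 * H, H, bias=False)           # fusion layer over 3 verifier layers
lm_head = nn.Linear(H, draft_vocab, bias=False)

verifier_hidden = torch.randn(B, T, 3 * H)     # concatenated verifier hidden states
embeds = torch.randn(B, T, H)                  # embeddings of the draft input_ids

fused = fc(verifier_hidden)                    # [B, T, H]
layer_input = torch.cat([embeds, fused], -1)   # [B, T, 2H], fed to Eagle3DecoderLayer
logits = lm_head(torch.randn(B, T, H))         # stand-in for the decoder layer output
print(layer_input.shape, logits.shape)         # [2, 5, 512] [2, 5, 1000]
```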
+    def compute_logits(
+        self,
+        hidden_states: torch.FloatTensor,
+        map_to_target_vocab: bool = True,
+    ) -> torch.FloatTensor:
+        """
+        Compute logits with optional vocabulary mapping.
+
+        :param hidden_states: Hidden states from the model
+        :param map_to_target_vocab: Whether to map draft logits to target vocabulary
+        :return: Logits tensor
+        """
+        logits = self.lm_head(hidden_states)
+
+        if not map_to_target_vocab:
+            return logits
+
+        batch_size, seq_length, _ = logits.shape
+
+        draft_indices = torch.arange(self.draft_vocab_size, device=logits.device)
+
+        target_indices = draft_indices + self.d2t
+
+        mapped_logits = logits.new_full(
+            (batch_size, seq_length, self.target_vocab_size), float("-inf")
+        )
+
+        mapped_logits[:, :, target_indices] = logits
+
+        return mapped_logits
+
+    def map_draft_to_target_tokens(
+        self, draft_tokens: torch.LongTensor
+    ) -> torch.LongTensor:
+        """
+        Map draft token IDs to target token IDs.
+
+        :param draft_tokens: Draft vocabulary token IDs
+        :return: Target vocabulary token IDs
+        """
+        return draft_tokens + self.d2t[draft_tokens]  # type: ignore[return-value]
+
+    def check_target_token_availability(
+        self, target_tokens: torch.LongTensor
+    ) -> torch.BoolTensor:
+        """
+        Check if target tokens have draft equivalents.
+
+        :param target_tokens: Target vocabulary token IDs
+        :return: Boolean mask indicating availability in draft vocabulary
+        """
+        return self.t2d[target_tokens]  # type: ignore[return-value]
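Annotation: `compute_logits` places the draft head's logits into a target-vocab-sized tensor filled with -inf, using the `d2t` offset buffer so that draft index i lands at target index i + d2t[i]. A self-contained toy version of that scatter (sizes and offsets invented for illustration):

```python
import torch

draft_vocab, target_vocab = 4, 10
d2t = torch.tensor([0, 2, 5, 6])         # invented offsets: draft i -> target i + d2t[i]

logits = torch.randn(1, 1, draft_vocab)  # draft-vocabulary logits from the lm_head
target_indices = torch.arange(draft_vocab) + d2t

mapped = logits.new_full((1, 1, target_vocab), float("-inf"))
mapped[:, :, target_indices] = logits    # same scatter as compute_logits

print(target_indices.tolist())           # [0, 3, 7, 9]
print(mapped[0, 0])                      # -inf everywhere except those 4 positions
```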
generation_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "_from_model_config": true,
+  "transformers_version": "4.52.4"
+}
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba637bf5e55fed7ab4d59ca514f0e1052d62229945a83be0b619bb9860426a42
+size 3800490840