Upload 4 files

- configuration_rwkv_hybrid.py  +8 -6
- modeling_rwkv_hybrid.py       +64 -147
- wkv.py                        +1 -2
configuration_rwkv_hybrid.py
CHANGED
@@ -15,9 +15,9 @@
 # limitations under the License.
 """RwkvHybrid model configuration"""

-from
-from
-from
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
 from typing import Optional, Union, List


@@ -218,15 +218,17 @@ class RwkvHybridConfig(PretrainedConfig):
             raise NotImplementedError(f"Unsupported wkv_version: {self.wkv_version}, \
                 wkv_version must be 6 or 7")

-        if wkv_layers == "full" or wkv_layers
+        if wkv_layers == "full" or wkv_layers is None:
             self.wkv_layers = list(range(num_hidden_layers))
         elif isinstance(wkv_layers, list):
             if all(isinstance(layer, int) for layer in wkv_layers):
                 self.wkv_layers = wkv_layers
             else:
-                raise ValueError(
+                raise ValueError(
+                    "All elements in wkv_layers must be integers.")
         else:
-            raise TypeError(
+            raise TypeError(
+                "wkv_layers must be either 'full', None, or a list of integers.")

         # for backward compatibility
         if num_key_value_heads is None:
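For reference, the validation added in this hunk accepts "full", None, or an explicit list of layer indices and rejects everything else. Below is a standalone sketch of the same check; the helper name `resolve_wkv_layers` is ours, not part of the repo.

from typing import List, Optional, Union

def resolve_wkv_layers(wkv_layers: Optional[Union[str, List[int]]],
                       num_hidden_layers: int) -> List[int]:
    # "full" or None -> every decoder layer gets an RWKV time mixer
    if wkv_layers == "full" or wkv_layers is None:
        return list(range(num_hidden_layers))
    if isinstance(wkv_layers, list):
        if all(isinstance(layer, int) for layer in wkv_layers):
            return wkv_layers
        raise ValueError("All elements in wkv_layers must be integers.")
    raise TypeError("wkv_layers must be either 'full', None, or a list of integers.")

print(resolve_wkv_layers(None, 4))    # [0, 1, 2, 3]
print(resolve_wkv_layers([0, 2], 4))  # [0, 2]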
modeling_rwkv_hybrid.py
CHANGED
@@ -1,22 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from transformers.cache_utils import Cache

-from
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, StaticCache
 from .hybrid_cache import HybridCache
-from
-from
-from
-from
+from transformers.generation import GenerationMixin
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

-from
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from
-from
+from transformers.processing_utils import Unpack
+from transformers.utils import (
     LossKwargs,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,

@@ -27,104 +28,72 @@ import threading
 from .wkv import Rwkv7Attention, Rwkv6Attention
 from .configuration_rwkv_hybrid import RwkvHybridConfig

-from
-    Qwen2RMSNorm,
-    Qwen2RotaryEmbedding,
+from transformers.models.qwen2.modeling_qwen2 import (Qwen2MLP,
+                                                      Qwen2RMSNorm,
+                                                      Qwen2RotaryEmbedding,
                                                       Qwen2Attention)

 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "RwkvHybridConfig"

-
 class RwkvHybridDecoderLayer(nn.Module):
-    def __init__(self, config: RwkvHybridConfig, layer_idx: int):
+    def __init__(self, config: RwkvHybridConfig, layer_idx: int, update_v_first, get_v_first):
         super().__init__()
         self.hidden_size = config.hidden_size

         self.is_rwkv = True if layer_idx in config.wkv_layers else False
         if self.is_rwkv:
             if config.wkv_version == 7:
-                self.self_attn = Rwkv7Attention(
-
+                self.self_attn = Rwkv7Attention(args=config, layer_id=layer_idx,
+                                                update_v_first=update_v_first,
+                                                get_v_first=get_v_first)
             elif config.wkv_version == 6:
-                self.self_attn = Rwkv6Attention(
-
+                self.self_attn = Rwkv6Attention(args=config, layer_id=layer_idx,
+                                                update_v_first=update_v_first,
+                                                get_v_first=get_v_first)
             else:
                 raise NotImplementedError
-
+        elif not self.is_rwkv:
             self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
+        else:
+            self.self_attn = None
+            raise NotImplementedError

         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps)
-        self.layer_idx = layer_idx
+            config.hidden_size, eps=config.rms_norm_eps)

+
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.
-        position_embeddings: Optional[torch.Tensor] = None,
-        sequence_mask: Optional[torch.Tensor] = None,
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        v_first: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
-        if sequence_mask is not None:
-            assert len(sequence_mask.shape) == 2, (
-                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-                "for padding purposes (0 indicating padding). "
-                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-            )
-            hidden_states = hidden_states.mul(
-                sequence_mask[:, -hidden_states.shape[-2]:, None])
-
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)

         # RWKV attention
-
-            hidden_states,
-
-
-
-
-
-
-
-            v_first=v_first,
-            **kwargs
-        )
-        else:
-            hidden_states, self_attn_weights = self.self_attn(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-                cu_seq_lens_q=cu_seq_lens_q,
-                cu_seq_lens_k=cu_seq_lens_k,
-                max_length_q=max_length_q,
-                max_length_k=max_length_k,
-                **kwargs
-            )
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
         hidden_states = residual + hidden_states

         # Fully Connected

@@ -137,12 +106,8 @@ class RwkvHybridDecoderLayer(nn.Module):
         if output_attentions:
             outputs += (self_attn_weights,)

-        if self.is_rwkv:
-            outputs += (v_first,)
-
         return outputs

-
 RWKV_HYBRID_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads

@@ -159,7 +124,6 @@ RWKV_HYBRID_START_DOCSTRING = r"""
     [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """

-
 @add_start_docstrings(
     "The bare RWKV Hybrid Model outputting raw hidden-states without any specific head on top.",
     RWKV_HYBRID_START_DOCSTRING,

@@ -182,7 +146,6 @@ class RwkvHybridPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

-
 RWKV_HYBRID_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

@@ -275,13 +238,11 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.thread_local = threading.local()
         self.thread_local.v_first = None
         self.layers = nn.ModuleList(
-            [RwkvHybridDecoderLayer(config, layer_idx)
-             for layer_idx in range(config.num_hidden_layers)]
+            [RwkvHybridDecoderLayer(config, layer_idx, self.update_v_first, self.get_v_first) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = Qwen2RotaryEmbedding(config=config)

@@ -305,20 +266,19 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         for layer in self.layers:
             layer.self_attn.time_mixer.post_init()

+    def update_v_first(self, new_v_first):
+        """Callback function to update v_first in HybridModel."""
+        self.thread_local.v_first = new_v_first
+
+    def get_v_first(self):
+        return self.thread_local.v_first
+
     def get_input_embeddings(self):
         return self.embed_tokens

     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    def get_v_first(self, layer_idx: int, use_cache: bool, past_key_value: HybridCache):
-        if layer_idx == 0:
-            return None
-
-        if use_cache:
-            return past_key_value.get_v_first()
-        return self.v_first
-
     @add_start_docstrings_to_model_forward(RWKV_HYBRID_INPUTS_DOCSTRING)
     def forward(
         self,

@@ -332,12 +292,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -347,8 +302,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(

@@ -363,8 +317,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             past_key_values = HybridCache()

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length(
-            ) if past_key_values is not None else 0
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )

@@ -386,7 +339,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         all_self_attns = () if output_attentions else None

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            first_rwkv_layer = True
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

@@ -401,14 +353,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                     use_cache,
                     cache_position,
                     position_embeddings,
-                    attention_mask,
-                    cu_seq_lens_q,
-                    cu_seq_lens_k,
-                    max_length_q,
-                    max_length_k,
-                    cu_seqlens,
-                    self.get_v_first(decoder_layer.layer_idx,
-                                     use_cache, past_key_values)
                 )
             else:
                 layer_outputs = decoder_layer(

@@ -420,14 +364,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
-
-                    cu_seq_lens_q=cu_seq_lens_q,
-                    cu_seq_lens_k=cu_seq_lens_k,
-                    max_length_q=max_length_q,
-                    max_length_k=max_length_k,
-                    cu_seqlens=cu_seqlens,
-                    v_first=self.get_v_first(
-                        decoder_layer.layer_idx, use_cache, past_key_values)
+                    **flash_attn_kwargs,
                 )

             hidden_states = layer_outputs[0]

@@ -435,14 +372,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

-            if first_rwkv_layer is True and decoder_layer.is_rwkv:
-                v_first = layer_outputs[-1]
-                if use_cache:
-                    past_key_values.update_v_first(v_first)
-                else:
-                    self.register_buffer('v_first', v_first)
-                first_rwkv_layer = False
-
         hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer

@@ -473,8 +402,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length(
-        ) if past_key_values is not None else 0
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)

         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward

@@ -519,8 +447,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
             min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask

@@ -563,20 +490,16 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-                (sequence_length,
-                 target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length,
-
-            causal_mask = causal_mask[None, None,
-                                      :, :].expand(batch_size, 1, -1, -1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :,
-                                           :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype

@@ -585,9 +508,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        return causal_mask


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
-    ...
-
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

@@ -597,8 +518,7 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        super().__init__(config)
        self.model = RwkvHybridModel(config)
        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(
-            config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

@@ -628,8 +548,7 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache,
-                                        List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,

@@ -692,15 +611,12 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        )

        hidden_states = outputs[0]
-        # Only compute necessary logits,
-        # and do not upcast them to float if we are not computing the loss
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
-            loss = self.loss_function(
-                logits=logits, labels=labels,
-                vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]

@@ -713,3 +629,4 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
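The main refactor in modeling_rwkv_hybrid.py replaces the explicit `v_first` plumbing (each RWKV layer returned `v_first`, which the model then stored in the `HybridCache` or a registered buffer) with two callbacks, `update_v_first` and `get_v_first`, backed by `threading.local()` storage on the model and handed to every decoder layer at construction time. The sketch below shows just that callback pattern; `TinyLayer` and `TinyModel` are illustrative stand-ins, not the repo's classes.

import threading
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    """Stand-in for RwkvHybridDecoderLayer: layer 0 publishes v_first through
    the callback, later layers read it back instead of receiving it as an argument."""
    def __init__(self, layer_id, update_v_first, get_v_first):
        super().__init__()
        self.layer_id = layer_id
        self.update_v_first = update_v_first
        self.get_v_first = get_v_first

    def forward(self, x):
        if self.layer_id == 0:
            self.update_v_first(x.detach())   # first layer sets v_first
        else:
            v_first = self.get_v_first()      # later layers reuse it
            x = x + 0.0 * v_first.sum()       # placeholder use of v_first
        return x

class TinyModel(nn.Module):
    """Stand-in for RwkvHybridModel: owns the thread-local slot and the callbacks."""
    def __init__(self, num_layers=3):
        super().__init__()
        self.thread_local = threading.local()
        self.thread_local.v_first = None
        self.layers = nn.ModuleList(
            [TinyLayer(i, self.update_v_first, self.get_v_first) for i in range(num_layers)]
        )

    def update_v_first(self, new_v_first):
        self.thread_local.v_first = new_v_first

    def get_v_first(self):
        return self.thread_local.v_first

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(TinyModel()(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])

One caveat of `threading.local()` in this position: attributes assigned in `__init__` are only visible to the thread that constructed the model, so a forward pass launched from a different thread sees no `v_first` until something running in that thread writes it.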
wkv.py
CHANGED
@@ -279,8 +279,7 @@ class Rwkv_Tmix_x070(nn.Module):
            self.a0 + (xa @ self.a1) @ self.a2
        )  # a is "in-context learning rate"
        if self.args.wkv_has_gate:
-
-            g = 1.0 + g_delta
+            g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
        kk = k * self.k_k
        kk = F.normalize(kk.view(B, T, self.n_head, -1),
                         p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
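The single change in wkv.py folds the gate computation into one expression. Below is a minimal sketch of the new formula; only `g = sigmoid(xg @ g1) @ g2 + 1.0` comes from the diff (the first of the two removed lines is truncated on this page, the second was `g = 1.0 + g_delta`), while the tensor names and sizes are illustrative assumptions.

import torch

B, T, C, G = 2, 4, 8, 16       # batch, time, channels, gate bottleneck (illustrative)
xg = torch.randn(B, T, C)      # gate input produced by the token-shift mix (assumed shape)
g1 = torch.randn(C, G) * 0.01  # low-rank down-projection (assumed shape)
g2 = torch.randn(G, C) * 0.01  # low-rank up-projection (assumed shape)

# New form, matching the added line in the diff:
g = torch.sigmoid(xg @ g1) @ g2 + 1.0
print(g.shape)                 # torch.Size([2, 4, 8])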