zhiyuan8 committed on
Commit 353a1a6 · verified · 1 parent: b087ae1

Upload 3 files

Files changed (3)
  1. configuration_rwkv_hybrid.py +252 -0
  2. hybrid_cache.py +154 -0
  3. wkv.py +604 -0
configuration_rwkv_hybrid.py ADDED
@@ -0,0 +1,252 @@
# coding=utf-8
# Copyright 2025 RWKV team. All rights reserved.
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RwkvHybrid model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from typing import Optional, Union, List


logger = logging.get_logger(__name__)


class RwkvHybridConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RwkvHybridModel`]. It is used to instantiate a
    RwkvHybrid model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of RwkvHybrid-7B-beta.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the RwkvHybrid model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`RwkvHybridModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
            type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update
            this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation.
                    If unspecified, it defaults to the value recommended by the implementation, using the `factor`
                    field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top
            use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_size (`int`, *optional*, defaults to 64):
            Dimensionality of each RWKV attention head. Defines the hidden dimension size for RWKV attention
            mechanisms.
        head_size_divisor (`int`, *optional*, defaults to 8):
            Constraint for head_size initialization, typically set to the square root of head_size. Ensures
            divisibility between hidden_size and head_size.
        wkv_version (`int`, *optional*, defaults to 7):
            Version of the RWKV attention implementation. Currently supports:
            - 6: Original implementation requiring `wkv_has_gate=True` and `wkv_use_vfirst=False`
            - 7: Improved version requiring `wkv_use_vfirst=True`
        wkv_has_gate (`bool`, *optional*, defaults to `False`):
            Whether to include a gating mechanism in RWKV attention. Required for version 6.
        wkv_has_group_norm (`bool`, *optional*, defaults to `True`):
            Whether to apply group normalization in RWKV attention layers.
        wkv_use_vfirst (`bool`, *optional*, defaults to `True`):
            Whether to prioritize value projection in RWKV attention computation. Required for version 7.
        wkv_layers (`Union[str, List[int]]`, *optional*, defaults to `None`):
            Specifies which layers use RWKV attention:
            - `"full"` or `None`: All layers use RWKV
            - List of integers: Only the specified layers (e.g., `[0, 1, 2]`) use RWKV attention

    ```python
    >>> from transformers import RwkvHybridModel, RwkvHybridConfig

    >>> # Initializing a RwkvHybrid style configuration
    >>> configuration = RwkvHybridConfig()

    >>> # Initializing a model from the RwkvHybrid-7B style configuration
    >>> model = RwkvHybridModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rwkv_hybrid"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `RwkvHybrid`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size: int = 151936,
        hidden_size: int = 4096,
        intermediate_size: int = 22016,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        head_size: int = 64,
        head_size_divisor: int = 8,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = False,
        rope_theta: float = 10000.0,
        rope_scaling: Optional[dict] = None,
        use_sliding_window: bool = False,
        sliding_window: int = 4096,
        max_window_layers: int = 28,
        attention_dropout: float = 0.0,
        wkv_version: int = 7,
        wkv_has_gate: bool = False,
        wkv_has_group_norm: bool = True,
        wkv_use_vfirst: bool = True,
        wkv_layers: Optional[Union[str, List[int]]] = None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_wkv_heads = hidden_size // head_size
        assert hidden_size % head_size == 0, "hidden_size must be divisible by head_size"
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers
        self.head_size = head_size
        self.head_size_divisor = head_size_divisor
        self.wkv_version = wkv_version

        self.wkv_has_gate = wkv_has_gate
        self.wkv_has_group_norm = wkv_has_group_norm
        self.wkv_use_vfirst = wkv_use_vfirst

        if self.wkv_version == 7:
            assert self.wkv_use_vfirst, "wkv_use_vfirst must be True for wkv_version 7"
        elif self.wkv_version == 6:
            assert self.wkv_has_gate, "wkv_has_gate must be True for wkv_version 6"
            assert not self.wkv_use_vfirst, "wkv_use_vfirst must be False for wkv_version 6"
        else:
            raise NotImplementedError(
                f"Unsupported wkv_version: {self.wkv_version}, wkv_version must be 6 or 7"
            )

        if wkv_layers == "full" or wkv_layers is None:
            self.wkv_layers = list(range(num_hidden_layers))
        elif isinstance(wkv_layers, list):
            if all(isinstance(layer, int) for layer in wkv_layers):
                self.wkv_layers = wkv_layers
            else:
                raise ValueError("All elements in wkv_layers must be integers.")
        else:
            raise TypeError("wkv_layers must be either 'full', None, or a list of integers.")

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
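For quick reference, below is a minimal usage sketch of the constraints that `__init__` enforces: version 7 requires `wkv_use_vfirst=True`, version 6 requires `wkv_has_gate=True` and `wkv_use_vfirst=False`, and `wkv_layers` may be `"full"`, `None`, or a list of layer indices. The flat import path is an assumption for illustration; in the released checkpoints the class is normally loaded through `trust_remote_code` rather than from core `transformers`.

```python
# Minimal sketch of the configuration constraints (assumed flat import path).
from configuration_rwkv_hybrid import RwkvHybridConfig

# Defaults: wkv_version=7, which requires wkv_use_vfirst=True; every layer uses RWKV.
cfg_v7 = RwkvHybridConfig(num_hidden_layers=4)
assert cfg_v7.wkv_layers == [0, 1, 2, 3]
assert cfg_v7.num_wkv_heads == cfg_v7.hidden_size // cfg_v7.head_size

# wkv_version=6 additionally requires wkv_has_gate=True and wkv_use_vfirst=False;
# here only layers 0 and 2 use RWKV attention, the rest use standard attention.
cfg_v6 = RwkvHybridConfig(
    num_hidden_layers=4,
    wkv_version=6,
    wkv_has_gate=True,
    wkv_use_vfirst=False,
    wkv_layers=[0, 2],
)
assert cfg_v6.wkv_layers == [0, 2]
```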
hybrid_cache.py ADDED
@@ -0,0 +1,154 @@
import torch
from typing import Any, Dict, Optional, Union
from transformers.cache_utils import DynamicCache


class TimeMixState:
    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state


class ChannelMixState:
    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state


class BlockState:
    def __init__(self, time_mix_state: TimeMixState,
                 channel_mix_state: ChannelMixState):
        self.time_mix_state = time_mix_state
        self.channel_mix_state = channel_mix_state


class BlockStateList:
    def __init__(self, shift_states, wkv_states):
        self.wkv_states = wkv_states
        self.shift_states = shift_states

    @staticmethod
    def create(N, B, C, H, device, dtype):
        result = BlockStateList.empty(N, B, C, H, device, dtype)
        result.wkv_states[:] = 0
        result.shift_states[:] = 0
        return result

    @staticmethod
    def empty(N, B, C, H, device, dtype):
        wkv_states = torch.empty((N, B, H, C // H, C // H),
                                 device=device,
                                 dtype=torch.bfloat16)
        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
        return BlockStateList(shift_states, wkv_states)

    def __getitem__(self, layer: int):
        return BlockState(
            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
            ChannelMixState(self.shift_states[layer, 1]))

    def __setitem__(self, layer: int, state: BlockState):
        self.shift_states[layer, 0] = state.time_mix_state.shift_state
        self.wkv_states[layer] = state.time_mix_state.wkv_state
        self.shift_states[layer, 1] = state.channel_mix_state.shift_state


class HybridCache(DynamicCache):
    def __init__(self) -> None:
        super().__init__()
        self.rwkv_layers = set()

    def __repr__(self) -> str:
        rwkv_layers = f"HybridCache(rwkv_layers={self.rwkv_layers})"
        # count the number of key_cache and value_cache entries
        key_cache_count = sum(len(cache) for cache in self.key_cache)
        value_cache_count = sum(len(cache) for cache in self.value_cache)
        count_info = rwkv_layers + \
            f", key_cache_count={key_cache_count}, value_cache_count={value_cache_count}"
        memories = 0
        seq_length = self.get_seq_length()
        for cache in self.value_cache:
            for data in cache:
                if not isinstance(data, torch.Tensor):
                    memories += data.time_mix_state.wkv_state.numel()
                else:
                    memories += data.numel()
        count_info += f", memories={memories / 1024 / 1024}MB, seq_length={seq_length}"
        return count_info

    def update(self,
               key_states: Union[int, torch.Tensor],
               value_states: Union[torch.Tensor, BlockState],
               layer_idx: int,
               cache_kwargs: Optional[Dict[str, Any]] = None):
        if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
            self.rwkv_layers.add(layer_idx)
            if layer_idx >= len(self.key_cache):
                self.key_cache.append([])
                self.value_cache.append([])

            if len(self.key_cache[layer_idx]) == 0:
                self.key_cache[layer_idx].append(key_states)
                self.value_cache[layer_idx].append(value_states)
            else:
                self.key_cache[layer_idx][0] = self.key_cache[layer_idx][0] + key_states
                self.value_cache[layer_idx][0] = value_states

            return key_states, value_states

        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        if layer_idx in self.rwkv_layers:
            return self.key_cache[layer_idx][0]
        return super().get_seq_length(layer_idx)

    def get_max_length(self):
        return super().get_max_length()

    def reorder_cache(self, beam_idx):
        return super().reorder_cache(beam_idx)

    def __getitem__(self, item):
        if item in self.rwkv_layers:
            return self.value_cache[item]
        return super().__getitem__(item)

    def offload_to_cpu(self):
        for cache in self.value_cache:
            for idx, data in enumerate(cache):
                if isinstance(data, torch.Tensor):
                    # Tensor.cpu() is not in-place, so the result must be stored back.
                    cache[idx] = data.cpu()
                else:
                    data.time_mix_state.wkv_state = data.time_mix_state.wkv_state.cpu()
                    data.time_mix_state.shift_state = data.time_mix_state.shift_state.cpu()

    def offload_to_cuda(self, device: str):
        for cache in self.value_cache:
            for idx, data in enumerate(cache):
                if isinstance(data, torch.Tensor):
                    # Tensor.cuda() is not in-place, so the result must be stored back.
                    cache[idx] = data.cuda(device)
                else:
                    data.time_mix_state.wkv_state = data.time_mix_state.wkv_state.cuda(device)
                    data.time_mix_state.shift_state = data.time_mix_state.shift_state.cuda(device)

    def offload_to_device(self, device_type: str, device_id: int = 0):
        for cache in self.value_cache:
            for idx, data in enumerate(cache):
                if isinstance(data, torch.Tensor):
                    method = getattr(data, device_type)
                    if device_type == 'cpu':
                        cache[idx] = method()
                    else:
                        cache[idx] = method(device_id)
                else:
                    wkv_state_method = getattr(
                        data.time_mix_state.wkv_state, device_type)
                    shift_state_method = getattr(
                        data.time_mix_state.shift_state, device_type)
                    if device_type == 'cpu':
                        data.time_mix_state.wkv_state = wkv_state_method()
                        data.time_mix_state.shift_state = shift_state_method()
                    else:
                        data.time_mix_state.wkv_state = wkv_state_method(device_id)
                        data.time_mix_state.shift_state = shift_state_method(device_id)
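The key behavior of `HybridCache.update` is that it is called in two different ways: RWKV layers pass an integer token count plus their recurrent block state, while ordinary attention layers fall through to `DynamicCache.update` with key/value tensors. Below is a small sketch of both paths; the tensor shapes are illustrative only, and it assumes a `transformers` version whose `DynamicCache` exposes `key_cache`/`value_cache` as plain Python lists, as this subclass does.

```python
# Illustrative sketch of the two HybridCache.update paths (shapes are examples only).
import torch
from hybrid_cache import HybridCache, TimeMixState, ChannelMixState, BlockState

cache = HybridCache()

# RWKV layer (layer 0): key_states is the number of new tokens (an int) and
# value_states is the recurrent state; the token count accumulates in key_cache.
state = BlockState(
    TimeMixState(shift_state=torch.zeros(1, 4096),
                 wkv_state=torch.zeros(1, 64, 64, 64)),
    ChannelMixState(shift_state=torch.zeros(1, 4096)),
)
cache.update(16, state, layer_idx=0)   # prefill of 16 tokens
cache.update(1, state, layer_idx=0)    # one decoding step
print(cache.get_seq_length(0))         # 17 -- cumulative token count for this layer

# Standard attention layer (layer 1): falls through to DynamicCache.update,
# which stores and concatenates key/value tensors along the sequence dimension.
k = torch.zeros(1, 32, 16, 128)  # (batch, kv_heads, seq, head_dim), example shape
v = torch.zeros(1, 32, 16, 128)
cache.update(k, v, layer_idx=1)
```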
wkv.py ADDED
@@ -0,0 +1,604 @@
import torch
from einops import rearrange

import math
import torch.nn as nn
from torch.nn import functional as F
from .configuration_rwkv_hybrid import RwkvHybridConfig
from typing import Optional
from .hybrid_cache import HybridCache, AttnState, BlockState

try:
    import triton  # pylint: disable=F401
    from rwkvfla.ops.rwkv7 import (
        fused_recurrent_rwkv7,
        chunk_rwkv7,
        native_recurrent_rwkv7,
        fused_addcmul_rwkv7,
    )  # pylint: disable=C0411
    from rwkvfla.ops.rwkv6 import (
        fused_recurrent_rwkv6,
        chunk_rwkv6,
        native_recurrent_rwkv6,
    )
except ImportError:
    from rwkvfla.ops.rwkv7 import native_recurrent_rwkv7  # pylint: disable=C0411
    from rwkvfla.ops.rwkv6 import native_recurrent_rwkv6
    from rwkvfla.ops.rwkv7 import torch_addcmul_rwkv7

    fused_recurrent_rwkv7 = native_recurrent_rwkv7
    chunk_rwkv7 = native_recurrent_rwkv7
    chunk_rwkv6 = native_recurrent_rwkv6
    fused_recurrent_rwkv6 = native_recurrent_rwkv6
    fused_addcmul_rwkv7 = torch_addcmul_rwkv7

from rwkvfla.utils import check_pytorch_version

if check_pytorch_version("2.6"):
    compile_decorator = torch.compile
    torch._dynamo.config.cache_size_limit = 512
else:
    def compile_decorator(func):
        return func


class Rwkv_Tmix_x070(nn.Module):
    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.hidden_size = args.hidden_size

        self.head_size = args.head_size
        self.n_head = args.num_wkv_heads
        assert args.hidden_size % self.n_head == 0
        H = self.n_head
        N = self.head_size

        self.x_r = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_w = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        D_DECAY_LORA = 64
        D_AAA_LORA = 64
        D_MV_LORA = 32
        D_GATE_LORA = 128

        self.w1 = nn.Parameter(torch.Tensor(args.hidden_size, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.Tensor(D_DECAY_LORA, args.hidden_size))
        self.w0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        self.a1 = nn.Parameter(torch.Tensor(args.hidden_size, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.Tensor(D_AAA_LORA, args.hidden_size))
        self.a0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        self.v1 = nn.Parameter(torch.Tensor(args.hidden_size, D_MV_LORA))
        self.v2 = nn.Parameter(torch.Tensor(D_MV_LORA, args.hidden_size))
        self.v0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))

        if self.args.wkv_has_gate:
            self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
            self.g1 = nn.Parameter(torch.Tensor(args.hidden_size, D_GATE_LORA))
            self.g2 = nn.Parameter(torch.Tensor(D_GATE_LORA, args.hidden_size))

        self.k_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.k_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
        self.r_k = nn.Parameter(torch.Tensor(H, N))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.key = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.value = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.output = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

        if self.args.wkv_has_group_norm:
            self.ln_x = nn.GroupNorm(
                H, args.hidden_size, eps=(1e-5) * (args.head_size_divisor**2)
            )

    def post_init(self):
        with torch.no_grad():
            ratio_0_to_1 = self.layer_id / (self.args.num_hidden_layers - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (self.layer_id / self.args.num_hidden_layers)  # 1 to ~0

            ddd = torch.ones(1, 1, self.args.hidden_size)
            for i in range(self.args.hidden_size):
                ddd[0, 0, i] = i / self.args.hidden_size

            # nn.init.constant_ only accepts scalar values, so the tensor-valued
            # initializations are written with copy_ instead.
            self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
            self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
            self.x_k.copy_(
                1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1)
            )
            self.x_v.copy_(
                1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1)
            )
            self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))

            def ortho_init(x, scale):
                shape = x.shape
                original_dtype = x.dtype
                x_fp32 = x.float()
                if len(shape) == 2:
                    gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
                    nn.init.orthogonal_(x_fp32, gain=gain * scale)
                elif len(shape) == 3:
                    gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
                    for i in range(shape[0]):
                        nn.init.orthogonal_(x_fp32[i], gain=gain * scale)
                else:
                    raise ValueError("ortho_init only supports 2D or 3D tensors")
                x.data.copy_(x_fp32.to(original_dtype))
                return x

            D_DECAY_LORA = 64
            nn.init.zeros_(self.w1)
            self.w2 = nn.Parameter(
                ortho_init(torch.zeros(D_DECAY_LORA, self.args.hidden_size), 0.1)
            )

            decay_speed = torch.ones(self.args.hidden_size)
            for n in range(self.args.hidden_size):
                decay_speed[n] = -7 + 5 * (n / (self.args.hidden_size - 1)) ** (
                    0.85 + 1.0 * ratio_0_to_1**0.5
                )
            self.w0.copy_(decay_speed.reshape(1, 1, self.args.hidden_size) + 0.5)

            D_AAA_LORA = 64
            nn.init.zeros_(self.a1)
            self.a2 = nn.Parameter(
                ortho_init(torch.zeros(D_AAA_LORA, self.args.hidden_size), 0.1)
            )
            nn.init.zeros_(self.a0)

            D_MV_LORA = 32
            nn.init.zeros_(self.v1)
            self.v2 = nn.Parameter(
                ortho_init(torch.zeros(D_MV_LORA, self.args.hidden_size), 0.1)
            )
            nn.init.constant_(self.v0, 1.0)

            D_GATE_LORA = 128
            if self.args.wkv_has_gate:
                nn.init.zeros_(self.g1)
                self.g2 = nn.Parameter(
                    ortho_init(torch.zeros(D_GATE_LORA, self.args.hidden_size), 0.1)
                )
                self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

            nn.init.constant_(self.k_k, 0.85)
            nn.init.constant_(self.k_a, 1.0)
            nn.init.zeros_(self.r_k)

            nn.init.zeros_(self.receptance.weight)
            nn.init.zeros_(self.key.weight)
            nn.init.zeros_(self.value.weight)
            nn.init.zeros_(self.output.weight)

            if self.args.wkv_has_group_norm:
                nn.init.ones_(self.ln_x.weight)
                nn.init.zeros_(self.ln_x.bias)

    def apply_wkv7_state(
        self, r, k, v, w, a, b, s,
        output_final_state,
        cu_seqlens
    ):
        if r.device.type == "cpu":
            r, w, k, v, a, b = map(lambda x: rearrange(
                x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
            o, state = native_recurrent_rwkv7(
                r=r, k=k, v=v, w=w,
                a=a, b=b,
                scale=1.0,
                initial_state=s.transpose(-1, -2),
                output_final_state=True,
                head_first=True,
            )
            state = state.transpose(-1, -2)
            x = rearrange(o, "b h l d -> b l (h d)")
        else:
            r, w, k, v, a, b = map(lambda x: rearrange(
                x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
            wkv7_func = chunk_rwkv7 if r.shape[1] != 1 else fused_recurrent_rwkv7
            o, state = wkv7_func(
                r=r, k=k, v=v, w=w,
                a=a, b=b,
                scale=1.0,
                initial_state=s,
                output_final_state=output_final_state,
                cu_seqlens=cu_seqlens,
                head_first=False,
            )
            x = rearrange(o, "b l h d -> b l (h d)")
        return x, state

    @compile_decorator
    def forward(
        self,
        hidden_states,
        last_state: AttnState,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        shift_state = last_state.shift_state
        B, T, C = hidden_states.size()

        xx = torch.concat((shift_state.unsqueeze(1), hidden_states[:, :-1]), dim=1) - hidden_states

        lx = hidden_states[:, -1]

        if self.args.wkv_has_gate:
            xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
        else:
            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)

        r = self.receptance(xr)
        w = (
            -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5
        )  # soft-clamp to (-inf, -0.5)
        k = self.key(xk)
        v = self.value(xv)
        if self.layer_id == 0:
            v_first = v
        else:
            v = torch.lerp(v, v_first, torch.sigmoid(
                self.v0 + (xv @ self.v1) @ self.v2
            ))  # add value residual

        if attention_mask is not None:
            v = v.mul(attention_mask[:, -v.shape[-2]:, None])
        a = torch.sigmoid(
            self.a0 + (xa @ self.a1) @ self.a2
        )  # a is "in-context learning rate"
        if self.args.wkv_has_gate:
            g_delta = torch.sigmoid(xg @ self.g1) @ self.g2
            g = 1.0 + g_delta
        kk = k * self.k_k
        kk = F.normalize(kk.view(B, T, self.n_head, -1),
                         p=2.0, dim=-1,
                         eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
        k = torch.lerp(k, k * a, self.k_a)

        wkv_state = last_state.wkv_state
        hidden_states, wkv_state = self.apply_wkv7_state(
            r,
            k,
            v,
            w,
            -kk,
            (kk * a),
            s=wkv_state,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens
        )
        if self.args.wkv_has_group_norm:
            hidden_states = self.ln_x(hidden_states.view(B * T, C)).view(B, T, C)

        # original code:
        # weighted_sum_rk = (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
        #     dim=-1, keepdim=True
        # )
        weighted_sum_rk = torch.einsum('btij,btij,ij->btij',
                                       r.view(B, T, self.n_head, -1),
                                       k.view(B, T, self.n_head, -1),
                                       self.r_k).sum(dim=-1, keepdim=True)
        hidden_states = hidden_states + \
            (weighted_sum_rk * v.view(B, T, self.n_head, -1)).view(B, T, C)
        hidden_states = self.output(
            hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
        return hidden_states, AttnState(lx, wkv_state), v_first


class Rwkv7Attention(nn.Module):
    def __init__(self, args: RwkvHybridConfig, layer_id):
        super().__init__()
        self.args = args
        self.layer_idx = layer_id
        self.time_mixer = Rwkv_Tmix_x070(args, layer_id)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[HybridCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):

        batch_size, token_length, _ = hidden_states.shape

        if use_cache and len(past_key_value) > self.layer_idx:
            last_state = past_key_value[self.layer_idx][0]
        else:
            last_state = self.init_state(
                batch_size, hidden_states.device, hidden_states.dtype
            )

        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
                                                       last_state=last_state.attn_state,
                                                       use_cache=use_cache,
                                                       cu_seqlens=cu_seqlens,
                                                       v_first=v_first,
                                                       **kwargs)

        if use_cache:
            last_state.attn_state = states
            past_key_value.update(token_length, last_state, self.layer_idx)

        return attn_output, None, v_first

    def init_state(self, batch_size, device, dtype) -> BlockState:
        wkv_states = torch.zeros(
            (
                batch_size,
                self.args.num_wkv_heads,
                self.args.head_size,
                self.args.head_size,
            ),
            device=device,
            dtype=torch.float32,
        )
        shift_states = torch.zeros(
            (batch_size, self.args.hidden_size), device=device, dtype=dtype
        )
        return BlockState(AttnState(shift_states, wkv_states), None)


class Rwkv_Tmix_x060(nn.Module):
    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.hidden_size = args.hidden_size

        self.head_size = args.head_size
        self.n_head = args.num_wkv_heads
        assert args.hidden_size % self.n_head == 0

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.num_hidden_layers - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.num_hidden_layers)  # 1 to ~0
            ddd = torch.ones(1, 1, args.hidden_size)
            for i in range(args.hidden_size):
                ddd[0, 0, i] = i / args.hidden_size

            # fancy time_mix
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(
                1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            )
            self.time_maa_r = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )
            self.time_maa_g = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )

            D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
            if args.hidden_size == 4096:
                D_MIX_LORA = D_MIX_LORA * 2
            self.time_maa_w1 = nn.Parameter(
                torch.zeros(args.hidden_size, D_MIX_LORA * 5)
            )
            self.time_maa_w2 = nn.Parameter(
                torch.zeros(5, D_MIX_LORA, args.hidden_size).uniform_(-0.01, 0.01)
            )

            # fancy time_decay
            decay_speed = torch.ones(args.head_size)
            for n in range(args.head_size):
                decay_speed[n] = -6 + 5 * (n / (args.head_size - 1)) ** (
                    0.7 + 1.3 * ratio_0_to_1
                )
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.head_size))

            D_DECAY_LORA = 64
            if args.hidden_size == 4096:
                D_DECAY_LORA = D_DECAY_LORA * 2
            self.time_decay_w1 = nn.Parameter(
                torch.zeros(args.hidden_size, D_DECAY_LORA)
            )
            self.time_decay_w2 = nn.Parameter(
                torch.zeros(D_DECAY_LORA, args.head_size).uniform_(-0.01, 0.01)
            )

            tmp = torch.zeros(args.head_size)
            for n in range(args.head_size):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.head_size - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.hidden_size, args.head_size, bias=False)
        self.key = nn.Linear(args.hidden_size, args.head_size, bias=False)

        self.value = nn.Linear(args.hidden_size, args.head_size, bias=False)
        self.output = nn.Linear(args.head_size, args.hidden_size, bias=False)
        self.gate = nn.Linear(args.hidden_size, args.head_size, bias=False)

        if self.args.wkv_has_group_norm:
            self.ln_x = nn.GroupNorm(
                self.n_head, args.head_size, eps=(1e-5) * (args.head_size_divisor**2)
            )

    def post_init(self):
        pass

    @compile_decorator
    def forward(
        self,
        hidden_states,
        last_state: AttnState,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
        shift_state = last_state.shift_state
        B, T, C = hidden_states.size()
        H = self.n_head

        xx = torch.concat((shift_state.unsqueeze(1), hidden_states[:, :-1]), dim=1) - hidden_states

        lx = hidden_states[:, -1]

        xxx = hidden_states + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = hidden_states + xx * (self.time_maa_w + mw)
        xk = hidden_states + xx * (self.time_maa_k + mk)
        xv = hidden_states + xx * (self.time_maa_v + mv)
        xr = hidden_states + xx * (self.time_maa_r + mr)
        xg = hidden_states + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        wkv_state = last_state.wkv_state
        hidden_states, wkv_state = self.apply_wkv6_state(
            B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
        )
        if self.args.wkv_has_group_norm:
            hidden_states = self.ln_x(hidden_states.view(B * T, C)).view(B, T, C)
        hidden_states = self.output(hidden_states * g)
        return hidden_states, AttnState(lx, wkv_state), None

    def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
        r, w, k, v = map(lambda x: rearrange(
            x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))

        if r.device.type == "cpu":
            wkv6_func = native_recurrent_rwkv6
        elif self.training:
            wkv6_func = chunk_rwkv6
        else:
            wkv6_func = fused_recurrent_rwkv6

        o, state = wkv6_func(
            r,
            k,
            v,
            -torch.exp(w),
            u=u,
            scale=1.0,
            initial_state=s,
            output_final_state=True,
        )
        x = rearrange(o, "b h l d -> b l (h d)")
        return x, state


class Rwkv6Attention(nn.Module):
    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
        super().__init__()
        self.args = args
        self.layer_idx = layer_id
        self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[HybridCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
        attn_output = hidden_states

        batch_size, token_length, _ = hidden_states.shape

        if use_cache and len(past_key_value) > self.layer_idx:
            last_state = past_key_value[self.layer_idx][0]
        else:
            last_state = self.init_state(
                batch_size, hidden_states.device, hidden_states.dtype
            )

        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
                                                       last_state=last_state.attn_state,
                                                       use_cache=use_cache,
                                                       cu_seqlens=cu_seqlens,
                                                       v_first=v_first,
                                                       **kwargs)

        if use_cache:
            last_state.attn_state = states
            past_key_value.update(token_length, last_state, self.layer_idx)

        return attn_output, None, v_first

    def init_state(self, batch_size, device, dtype) -> BlockState:
        wkv_states = torch.zeros(
            (
                batch_size,
                self.args.num_wkv_heads,
                self.args.head_size,
                self.args.head_size,
            ),
            device=device,
            dtype=torch.float32,
        )
        shift_states = torch.zeros(
            (batch_size, self.args.hidden_size), device=device, dtype=dtype
        )
        return BlockState(AttnState(shift_states, wkv_states), None)
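Both `Rwkv7Attention.init_state` and `Rwkv6Attention.init_state` allocate the same two pieces of per-layer recurrent state. Below is a shape sketch that depends only on the configuration class; the flat import path and the smaller sizes are example assumptions, not values from the released checkpoints.

```python
# Recurrent-state shapes allocated by init_state (example sizes, config-only sketch).
import torch
from configuration_rwkv_hybrid import RwkvHybridConfig

cfg = RwkvHybridConfig(hidden_size=512, head_size=64, num_hidden_layers=2)
B = 2  # example batch size

# WKV state: one (head_size x head_size) matrix per RWKV head, kept in float32.
wkv_state = torch.zeros(B, cfg.num_wkv_heads, cfg.head_size, cfg.head_size)  # (2, 8, 64, 64)

# Token-shift state: the last hidden vector seen by the layer, in the model dtype.
shift_state = torch.zeros(B, cfg.hidden_size)  # (2, 512)
```

During decoding these states are carried between calls through `HybridCache`, while `v_first` (the value stream produced by layer 0) is passed separately along the layer stack and blended into deeper layers as a value residual.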