zhiyuan8 committed on
Commit 3879d45 · verified · 1 Parent(s): 8e207ff

Upload 4 files

Files changed (3)
  1. configuration_rwkv_hybrid.py +8 -6
  2. modeling_rwkv_hybrid.py +64 -147
  3. wkv.py +1 -2
configuration_rwkv_hybrid.py CHANGED
@@ -15,9 +15,9 @@
 # limitations under the License.
 """RwkvHybrid model configuration"""
 
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-from transformers.utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
 from typing import Optional, Union, List
 
 
@@ -218,15 +218,17 @@ class RwkvHybridConfig(PretrainedConfig):
             raise NotImplementedError(f"Unsupported wkv_version: {self.wkv_version}, \
                 wkv_version must be 6 or 7")
 
-        if wkv_layers == "full" or wkv_layers == None:
+        if wkv_layers == "full" or wkv_layers is None:
             self.wkv_layers = list(range(num_hidden_layers))
         elif isinstance(wkv_layers, list):
             if all(isinstance(layer, int) for layer in wkv_layers):
                 self.wkv_layers = wkv_layers
             else:
-                raise ValueError("All elements in wkv_layers must be integers.")
+                raise ValueError(
+                    "All elements in wkv_layers must be integers.")
         else:
-            raise TypeError("wkv_layers must be either 'full', None, or a list of integers.")
+            raise TypeError(
+                "wkv_layers must be either 'full', None, or a list of integers.")
 
         # for backward compatibility
         if num_key_value_heads is None:
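
The validation above normalizes wkv_layers before the layer stack is built. A quick sketch of the contract it enforces, assuming the config's remaining constructor arguments keep their defaults (the values below are illustrative only):

    from configuration_rwkv_hybrid import RwkvHybridConfig

    RwkvHybridConfig(wkv_version=7, wkv_layers="full")      # every decoder layer uses RWKV time mixing
    RwkvHybridConfig(wkv_version=7, wkv_layers=None)        # treated the same as "full"
    RwkvHybridConfig(wkv_version=7, wkv_layers=[0, 2, 4])   # only these layer indices use RWKV
    # RwkvHybridConfig(wkv_version=7, wkv_layers=[0, "a"])  # would raise ValueError
    # RwkvHybridConfig(wkv_version=7, wkv_layers=3)         # would raise TypeError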
modeling_rwkv_hybrid.py CHANGED
@@ -1,22 +1,23 @@
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from transformers.cache_utils import Cache
 
-from ...cache_utils import Cache, StaticCache
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, StaticCache
 from .hybrid_cache import HybridCache
-from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_utils import PreTrainedModel
+from transformers.generation import GenerationMixin
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 
-from ...modeling_outputs import (
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from ...processing_utils import Unpack
-from ...utils import (
+from transformers.processing_utils import Unpack
+from transformers.utils import (
     LossKwargs,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -27,104 +28,72 @@ import threading
 from .wkv import Rwkv7Attention, Rwkv6Attention
 from .configuration_rwkv_hybrid import RwkvHybridConfig
 
-from ..qwen2.modeling_qwen2 import (Qwen2MLP,
-                                    Qwen2RMSNorm,
-                                    Qwen2RotaryEmbedding,
+from transformers.models.qwen2.modeling_qwen2 import (Qwen2MLP,
+                                                      Qwen2RMSNorm,
+                                                      Qwen2RotaryEmbedding,
                                                       Qwen2Attention)
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "RwkvHybridConfig"
 
-
 class RwkvHybridDecoderLayer(nn.Module):
-    def __init__(self, config: RwkvHybridConfig, layer_idx: int):
+    def __init__(self, config: RwkvHybridConfig, layer_idx: int, update_v_first, get_v_first):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.is_rwkv = True if layer_idx in config.wkv_layers else False
        if self.is_rwkv:
            if config.wkv_version == 7:
-                self.self_attn = Rwkv7Attention(
-                    args=config, layer_id=layer_idx)
+                self.self_attn = Rwkv7Attention(args=config, layer_id=layer_idx,
+                                                update_v_first=update_v_first,
+                                                get_v_first=get_v_first)
            elif config.wkv_version == 6:
-                self.self_attn = Rwkv6Attention(
-                    args=config, layer_id=layer_idx)
+                self.self_attn = Rwkv6Attention(args=config, layer_id=layer_idx,
+                                                update_v_first=update_v_first,
+                                                get_v_first=get_v_first)
            else:
                raise NotImplementedError
-        else:
+        elif not self.is_rwkv:
            self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
+        else:
+            self.self_attn = None
+            raise NotImplementedError

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps)
-        self.layer_idx = layer_idx
+            config.hidden_size, eps=config.rms_norm_eps)

+
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
-        sequence_mask: Optional[torch.Tensor] = None,
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        v_first: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
-        if sequence_mask is not None:
-            assert len(sequence_mask.shape) == 2, (
-                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-                "for padding purposes (0 indicating padding). "
-                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-            )
-            hidden_states = hidden_states.mul(
-                sequence_mask[:, -hidden_states.shape[-2]:, None])
-
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # RWKV attention
-        if self.is_rwkv:
-            hidden_states, self_attn_weights, v_first = self.self_attn(
-                hidden_states=hidden_states,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-                cu_seqlens=cu_seqlens,
-                v_first=v_first,
-                **kwargs
-            )
-        else:
-            hidden_states, self_attn_weights = self.self_attn(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-                cu_seq_lens_q=cu_seq_lens_q,
-                cu_seq_lens_k=cu_seq_lens_k,
-                max_length_q=max_length_q,
-                max_length_k=max_length_k,
-                **kwargs
-            )
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
        hidden_states = residual + hidden_states

        # Fully Connected
@@ -137,12 +106,8 @@ class RwkvHybridDecoderLayer(nn.Module):
        if output_attentions:
            outputs += (self_attn_weights,)

-        if self.is_rwkv:
-            outputs += (v_first,)
-
        return outputs

-
 RWKV_HYBRID_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -159,7 +124,6 @@ RWKV_HYBRID_START_DOCSTRING = r"""
        [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """

-
 @add_start_docstrings(
    "The bare RWKV Hybrid Model outputting raw hidden-states without any specific head on top.",
    RWKV_HYBRID_START_DOCSTRING,
@@ -182,7 +146,6 @@ class RwkvHybridPreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

-
 RWKV_HYBRID_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -275,13 +238,11 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.thread_local = threading.local()
        self.thread_local.v_first = None
        self.layers = nn.ModuleList(
-            [RwkvHybridDecoderLayer(config, layer_idx)
-             for layer_idx in range(config.num_hidden_layers)]
+            [RwkvHybridDecoderLayer(config, layer_idx, self.update_v_first, self.get_v_first) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
@@ -305,20 +266,19 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        for layer in self.layers:
            layer.self_attn.time_mixer.post_init()

+    def update_v_first(self, new_v_first):
+        """Callback function to update v_first in HybridModel."""
+        self.thread_local.v_first = new_v_first
+
+    def get_v_first(self):
+        return self.thread_local.v_first
+
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

-    def get_v_first(self, layer_idx: int, use_cache: bool, past_key_value: HybridCache):
-        if layer_idx == 0:
-            return None
-
-        if use_cache:
-            return past_key_value.get_v_first()
-        return self.v_first
-
    @add_start_docstrings_to_model_forward(RWKV_HYBRID_INPUTS_DOCSTRING)
    def forward(
        self,
@@ -332,12 +292,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
-        cu_seq_lens_q: Optional[torch.LongTensor] = None,
-        cu_seq_lens_k: Optional[torch.LongTensor] = None,
-        max_length_q: Optional[int] = None,
-        max_length_k: Optional[int] = None,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -347,8 +302,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
@@ -363,8 +317,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
            past_key_values = HybridCache()

        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length(
-            ) if past_key_values is not None else 0
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
@@ -386,7 +339,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            first_rwkv_layer = True
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

@@ -401,14 +353,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                    use_cache,
                    cache_position,
                    position_embeddings,
-                    attention_mask,
-                    cu_seq_lens_q,
-                    cu_seq_lens_k,
-                    max_length_q,
-                    max_length_k,
-                    cu_seqlens,
-                    self.get_v_first(decoder_layer.layer_idx,
-                                     use_cache, past_key_values)
                )
            else:
                layer_outputs = decoder_layer(
@@ -420,14 +364,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
-                    sequence_mask=attention_mask,
-                    cu_seq_lens_q=cu_seq_lens_q,
-                    cu_seq_lens_k=cu_seq_lens_k,
-                    max_length_q=max_length_q,
-                    max_length_k=max_length_k,
-                    cu_seqlens=cu_seqlens,
-                    v_first=self.get_v_first(
-                        decoder_layer.layer_idx, use_cache, past_key_values)
+                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]
@@ -435,14 +372,6 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

-            if first_rwkv_layer is True and decoder_layer.is_rwkv:
-                v_first = layer_outputs[-1]
-                if use_cache:
-                    past_key_values.update_v_first(v_first)
-                else:
-                    self.register_buffer('v_first', v_first)
-                first_rwkv_layer = False
-
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
@@ -473,8 +402,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length(
-        ) if past_key_values is not None else 0
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
@@ -519,8 +447,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

@@ -563,20 +490,16 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-                (sequence_length,
-                 target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length,
-                                        device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None,
-                                      :, :].expand(batch_size, 1, -1, -1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :,
-                                           :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
@@ -585,9 +508,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
        return causal_mask


-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
-    ...
-
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

 class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
@@ -597,8 +518,7 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        super().__init__(config)
        self.model = RwkvHybridModel(config)
        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(
-            config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
@@ -628,8 +548,7 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache,
-                                        List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
@@ -692,15 +611,12 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
        )

        hidden_states = outputs[0]
-        # Only compute necessary logits,
-        # and do not upcast them to float if we are not computing the loss
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
-            loss = self.loss_function(
-                logits=logits, labels=labels,
-                vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
@@ -713,3 +629,4 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
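
The modeling changes replace the old scheme of threading v_first through every forward signature (and through HybridCache or a registered buffer) with two callbacks, update_v_first and get_v_first, which the model hands to each decoder layer at construction time and which read and write thread-local state on the model. A standalone toy sketch of that wiring (hypothetical classes, not the repo's real modules; in the actual code the callbacks are consumed inside Rwkv7Attention / Rwkv6Attention):

    import threading

    class ToyLayer:
        def __init__(self, layer_idx, update_v_first, get_v_first):
            self.layer_idx = layer_idx
            self.update_v_first = update_v_first
            self.get_v_first = get_v_first

        def forward(self, x):
            if self.layer_idx == 0:
                self.update_v_first(x)   # the first layer publishes v_first
            else:
                _ = self.get_v_first()   # later layers read it back
            return x

    class ToyModel:
        def __init__(self, n_layers):
            self.thread_local = threading.local()
            self.thread_local.v_first = None
            # each layer receives the callbacks instead of a v_first tensor argument
            self.layers = [ToyLayer(i, self.update_v_first, self.get_v_first)
                           for i in range(n_layers)]

        def update_v_first(self, new_v_first):
            self.thread_local.v_first = new_v_first

        def get_v_first(self):
            return self.thread_local.v_first

        def forward(self, x):
            for layer in self.layers:
                x = layer.forward(x)
            return x

    ToyModel(n_layers=3).forward("hidden_states")

Keeping the shared value in threading.local() means concurrent forward passes on different threads do not overwrite each other's v_first.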
wkv.py CHANGED
@@ -279,8 +279,7 @@ class Rwkv_Tmix_x070(nn.Module):
             self.a0 + (xa @ self.a1) @ self.a2
         )  # a is "in-context learning rate"
         if self.args.wkv_has_gate:
-            g_delta = torch.sigmoid(xg @ self.g1) @ self.g2
-            g = 1.0 + g_delta
+            g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
         kk = k * self.k_k
         kk = F.normalize(kk.view(B, T, self.n_head, -1),
                          p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
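
The gate refactor folds the two-step computation into a single expression; the result is unchanged. A tiny numerical check of that equivalence (shapes are illustrative, not the model's real sizes; g1 and g2 stand in for the module's low-rank gate parameters):

    import torch

    B, T, C, r = 2, 3, 8, 4                  # batch, time, channels, low-rank width
    xg = torch.randn(B, T, C)
    g1 = torch.randn(C, r)                   # low-rank "down" projection
    g2 = torch.randn(r, C)                   # low-rank "up" projection

    g_delta = torch.sigmoid(xg @ g1) @ g2    # previous two-step form
    g_old = 1.0 + g_delta
    g_new = torch.sigmoid(xg @ g1) @ g2 + 1.0

    assert torch.allclose(g_old, g_new)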