zhiyuan8 committed (verified)
Commit 3cdbb2d · Parent: 4b2e602

Update modeling_rwkv_hybrid.py

Files changed (1): modeling_rwkv_hybrid.py (+135 −51)
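
This commit reworks how the RWKV value-state `v_first` is threaded through the model. Previously each `RwkvHybridDecoderLayer` was constructed with `update_v_first`/`get_v_first` callbacks bound to thread-local state on `RwkvHybridModel`; in the diff below, the RWKV layers instead take `v_first` as a forward argument and append the updated value to their outputs, and the model stores the value from the first RWKV layer in the `HybridCache` (via `update_v_first`) or in a `v_first` buffer when caching is off, handing it to later layers through the new `RwkvHybridModel.get_v_first(layer_idx, use_cache, past_key_value)`. The decoder layer also gains padding/varlen arguments (`sequence_mask`, `cu_seqlens`, `cu_seq_lens_q`/`cu_seq_lens_k`, `max_length_q`/`max_length_k`) and now dispatches explicitly between the RWKV path and the `Qwen2Attention` path.

The following is a minimal sketch of that propagation pattern only. `ToyCache`, `ToyRwkvLayer`, and `ToyModel` are illustrative stand-ins, not the real `HybridCache`, attention, or model classes from this file.

from typing import Optional

import torch
import torch.nn as nn


class ToyCache:
    """Stand-in for HybridCache: keeps v_first next to the per-request state."""

    def __init__(self) -> None:
        self._v_first: Optional[torch.Tensor] = None

    def update_v_first(self, v_first: torch.Tensor) -> None:
        self._v_first = v_first

    def get_v_first(self) -> Optional[torch.Tensor]:
        return self._v_first


class ToyRwkvLayer(nn.Module):
    """Stand-in for an RWKV decoder layer: consumes and returns v_first."""

    def __init__(self, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.is_rwkv = True
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor, v_first: Optional[torch.Tensor]):
        hidden_states = self.proj(hidden_states)
        if v_first is None:
            # The first RWKV layer produces v_first; later layers reuse it.
            v_first = hidden_states
        return hidden_states, v_first


class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 3) -> None:
        super().__init__()
        self.layers = nn.ModuleList([ToyRwkvLayer(i) for i in range(num_layers)])

    def get_v_first(self, layer_idx: int, use_cache: bool, cache: ToyCache):
        # Mirrors RwkvHybridModel.get_v_first: layer 0 starts fresh, later
        # layers read the value stored by the first RWKV layer.
        if layer_idx == 0:
            return None
        if use_cache:
            return cache.get_v_first()
        return self.v_first

    def forward(self, hidden_states: torch.Tensor, use_cache: bool = True,
                cache: Optional[ToyCache] = None) -> torch.Tensor:
        cache = cache if cache is not None else ToyCache()
        for layer in self.layers:
            hidden_states, v_first = layer(
                hidden_states,
                v_first=self.get_v_first(layer.layer_idx, use_cache, cache),
            )
            if layer.is_rwkv:
                # The real model guards this with a first_rwkv_layer flag and
                # falls back to register_buffer("v_first", ...) when not caching.
                if use_cache:
                    cache.update_v_first(v_first)
                else:
                    self.register_buffer("v_first", v_first)
        return hidden_states


if __name__ == "__main__":
    out = ToyModel()(torch.randn(2, 4, 8))
    print(out.shape)  # torch.Size([2, 4, 8])
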
modeling_rwkv_hybrid.py CHANGED
--- a/modeling_rwkv_hybrid.py
+++ b/modeling_rwkv_hybrid.py
@@ -37,63 +37,95 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "RwkvHybridConfig"
 
+
 class RwkvHybridDecoderLayer(nn.Module):
-    def __init__(self, config: RwkvHybridConfig, layer_idx: int, update_v_first, get_v_first):
+    def __init__(self, config: RwkvHybridConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.is_rwkv = True if layer_idx in config.wkv_layers else False
         if self.is_rwkv:
             if config.wkv_version == 7:
-                self.self_attn = Rwkv7Attention(args=config, layer_id=layer_idx,
-                                                update_v_first=update_v_first,
-                                                get_v_first=get_v_first)
+                self.self_attn = Rwkv7Attention(
+                    args=config, layer_id=layer_idx)
             elif config.wkv_version == 6:
-                self.self_attn = Rwkv6Attention(args=config, layer_id=layer_idx,
-                                                update_v_first=update_v_first,
-                                                get_v_first=get_v_first)
+                self.self_attn = Rwkv6Attention(
+                    args=config, layer_id=layer_idx)
             else:
                 raise NotImplementedError
-        elif not self.is_rwkv:
-            self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
         else:
-            self.self_attn = None
-            raise NotImplementedError
+            self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps)
+            config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        cache_position: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        sequence_mask: Optional[torch.Tensor] = None,
+        cu_seq_lens_q: Optional[torch.LongTensor] = None,
+        cu_seq_lens_k: Optional[torch.LongTensor] = None,
+        max_length_q: Optional[int] = None,
+        max_length_k: Optional[int] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        v_first: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        if sequence_mask is not None:
+            assert len(sequence_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+            hidden_states = hidden_states.mul(
+                sequence_mask[:, -hidden_states.shape[-2]:, None])
+
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
 
         # RWKV attention
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-        )
+        if self.is_rwkv:
+            hidden_states, self_attn_weights, v_first = self.self_attn(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                cu_seqlens=cu_seqlens,
+                v_first=v_first,
+                **kwargs
+            )
+        else:
+            hidden_states, self_attn_weights = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                cu_seq_lens_q=cu_seq_lens_q,
+                cu_seq_lens_k=cu_seq_lens_k,
+                max_length_q=max_length_q,
+                max_length_k=max_length_k,
+                **kwargs
+            )
         hidden_states = residual + hidden_states
 
         # Fully Connected
@@ -106,8 +138,12 @@ class RwkvHybridDecoderLayer(nn.Module):
         if output_attentions:
             outputs += (self_attn_weights,)
 
+        if self.is_rwkv:
+            outputs += (v_first,)
+
         return outputs
 
+
 RWKV_HYBRID_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -124,6 +160,7 @@ RWKV_HYBRID_START_DOCSTRING = r"""
             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
 
+
 @add_start_docstrings(
     "The bare RWKV Hybrid Model outputting raw hidden-states without any specific head on top.",
     RWKV_HYBRID_START_DOCSTRING,
@@ -146,6 +183,7 @@ class RwkvHybridPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+
 RWKV_HYBRID_INPUTS_DOCSTRING = r"""
     Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -238,11 +276,13 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx)
         self.thread_local = threading.local()
         self.thread_local.v_first = None
         self.layers = nn.ModuleList(
-            [RwkvHybridDecoderLayer(config, layer_idx, self.update_v_first, self.get_v_first) for layer_idx in range(config.num_hidden_layers)]
+            [RwkvHybridDecoderLayer(config, layer_idx)
+             for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = Qwen2RotaryEmbedding(config=config)
@@ -266,19 +306,20 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         for layer in self.layers:
             layer.self_attn.time_mixer.post_init()
 
-    def update_v_first(self, new_v_first):
-        """Callback function to update v_first in HybridModel."""
-        self.thread_local.v_first = new_v_first
-
-    def get_v_first(self):
-        return self.thread_local.v_first
-
     def get_input_embeddings(self):
         return self.embed_tokens
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
+    def get_v_first(self, layer_idx: int, use_cache: bool, past_key_value: HybridCache):
+        if layer_idx == 0:
+            return None
+
+        if use_cache:
+            return past_key_value.get_v_first()
+        return self.v_first
+
     @add_start_docstrings_to_model_forward(RWKV_HYBRID_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -292,7 +333,12 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        cu_seq_lens_q: Optional[torch.LongTensor] = None,
+        cu_seq_lens_k: Optional[torch.LongTensor] = None,
+        max_length_q: Optional[int] = None,
+        max_length_k: Optional[int] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -302,7 +348,8 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
@@ -317,7 +364,8 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             past_key_values = HybridCache()
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = past_key_values.get_seq_length(
+            ) if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
@@ -339,6 +387,7 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         all_self_attns = () if output_attentions else None
 
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            first_rwkv_layer = True
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -353,6 +402,14 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    attention_mask,
+                    cu_seq_lens_q,
+                    cu_seq_lens_k,
+                    max_length_q,
+                    max_length_k,
+                    cu_seqlens,
+                    self.get_v_first(decoder_layer.layer_idx,
+                                     use_cache, past_key_values)
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -364,7 +421,14 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
-                    **flash_attn_kwargs,
+                    sequence_mask=attention_mask,
+                    cu_seq_lens_q=cu_seq_lens_q,
+                    cu_seq_lens_k=cu_seq_lens_k,
+                    max_length_q=max_length_q,
+                    max_length_k=max_length_k,
+                    cu_seqlens=cu_seqlens,
+                    v_first=self.get_v_first(
+                        decoder_layer.layer_idx, use_cache, past_key_values)
                 )
 
             hidden_states = layer_outputs[0]
@@ -372,6 +436,14 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
+            if first_rwkv_layer is True and decoder_layer.is_rwkv:
+                v_first = layer_outputs[-1]
+                if use_cache:
+                    past_key_values.update_v_first(v_first)
+                else:
+                    self.register_buffer('v_first', v_first)
+                first_rwkv_layer = False
+
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
@@ -402,7 +474,8 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_seen_tokens = past_key_values.get_seq_length(
+        ) if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)
 
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
@@ -447,7 +520,8 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
             min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype)
 
         return causal_mask
 
@@ -490,16 +564,20 @@ class RwkvHybridModel(RwkvHybridPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length,
+                 target_length), fill_value=min_dtype, dtype=dtype, device=device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            causal_mask *= torch.arange(target_length,
+                                        device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None,
+                                      :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = causal_mask[:, :, :,
+                                           :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
@@ -508,7 +586,9 @@
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
+    ...
+
 
 class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
@@ -518,7 +598,8 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = RwkvHybridModel(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -548,7 +629,8 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[Cache,
+                                        List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -611,12 +693,15 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
         )
 
         hidden_states = outputs[0]
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        # Only compute necessary logits,
+        # and do not upcast them to float if we are not computing the loss
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(
+                logits=logits, labels=labels,
+                vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -629,4 +714,3 @@ class RwkvHybridForCausalLM(RwkvHybridPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
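
Custom modeling files like this one are typically picked up through transformers' remote-code path when the hosting checkpoint is loaded. Below is a hypothetical loading sketch: the repository id is a placeholder and the generation settings are illustrative assumptions, not taken from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-rwkv-hybrid-checkpoint"  # placeholder, not a real repo

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # pulls modeling_rwkv_hybrid.py from the repo
    torch_dtype="auto",
)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))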