Doctor-Shotgun committed
Commit 6742e10 · verified · 1 Parent(s): f97eef1

Add training monkeypatches

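Both files added below rely on import-time monkeypatching: they assign the Liger (and, in the first file, Cut Cross-Entropy) replacements onto `transformers.models.qwen3_moe.modeling_qwen3_moe` before importing any names from it, so the patched kernels are what the model classes actually use. A minimal sketch of that mechanism, using the same imports as the files below (the standalone-script framing is only for illustration):

```python
# Sketch of the monkeypatch mechanism shared by both files below.
# The module attribute must be replaced *before* other code imports the name,
# so later lookups resolve to the Liger kernel instead of the stock implementation.
from liger_kernel.transformers.rms_norm import LigerRMSNorm

import transformers.models.qwen3_moe.modeling_qwen3_moe as qwen3_moe_modeling

qwen3_moe_modeling.Qwen3MoeRMSNorm = LigerRMSNorm

# A name imported after the patch refers to the replacement:
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm

assert Qwen3MoeRMSNorm is LigerRMSNorm
```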
modeling_qwen3_shared_moe_monkeypatch_liger_cce.py ADDED
@@ -0,0 +1,280 @@
# coding=utf-8
# Copyright 2025 Charles O. Goddard, The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The following monkeypatches were applied by Doctor Shotgun:
#
# Liger Kernel (https://github.com/linkedin/Liger-Kernel):
# 1. Liger RMSNorm
# 2. Liger RoPE
# 3. Liger SwiGLUMLP
#
# Cut Cross-Entropy (https://github.com/apple/ml-cross-entropy):
# 1. Cut Cross-Entropy
"""PyTorch Qwen3 model with shared expert support."""

from typing import List, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F

# CCE Patch #
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    apply_lce,
)
_PATCH_OPTS = PatchOptions(
    impl=LCE_IMPL_DEFAULT,
    reduction="mean",
    filter_eps="auto",
    accum_e_fp32=False,
    accum_c_fp32=False,
    filter_e_grad=True,
    filter_c_grad=True,
    train_only=False,
)
# CCE Patch #

# Liger Patch #
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers.rope import liger_rotary_pos_emb

import transformers.models.qwen3_moe.modeling_qwen3_moe
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
# Liger Patch #

from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.models.mixtral.modeling_mixtral import (
    load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeMLP,
    Qwen3MoeRMSNorm,
    Qwen3MoeAttention,
    Qwen3MoeDecoderLayer,
    Qwen3MoeModel,
    Qwen3MoeForCausalLM,
)
from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig

import scattermoe


logger = logging.get_logger(__name__)


class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        if config.shared_expert_intermediate_size is not None:
            self.shared_expert = Qwen3MoeMLP(
                config, intermediate_size=config.shared_expert_intermediate_size
            )
        else:
            self.shared_expert = None
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=self.config.hidden_size,
            hidden_size=self.config.moe_intermediate_size,
            num_experts=self.config.num_experts,
            top_k=self.config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # handling of gate/router logits copied from Qwen3MoeSparseMoeBlock
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.config.num_experts_per_tok, dim=-1
        )
        if self.config.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # modified here to use scattermoe + shared_expert
        hs_0 = self.moe_mlp(hidden_states, routing_weights, selected_experts)

        if self.shared_expert is not None:
            shared_res = self.shared_expert(hidden_states)
            res = hs_0 + shared_res
        else:
            res = hs_0
        res = res.reshape(batch_size, sequence_length, hidden_dim)
        return res, router_logits


class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(config, layer_idx)

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class Qwen3SharedMoeModel(Qwen3MoeModel):
    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )


class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts

    # CCE Patch #
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_router_logits = (
            output_router_logits
            if output_router_logits is not None
            else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        if hidden_states is None:
            raise ValueError("hidden_states is None")

        loss = None
        logits = None

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )

        if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
            assert labels is not None
            loss = apply_lce(
                hidden_states[:, slice_indices, :],
                self.lm_head.weight,
                labels,
                _PATCH_OPTS,
                **kwargs,
            )
        else:
            logits = self.lm_head(hidden_states[:, slice_indices, :])

            if labels is not None:
                loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(
                    loss.device
                )  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
    # CCE Patch #
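A hedged usage sketch for the CCE variant above: it assumes the checkpoint repository ships this modeling file alongside `configuration_qwen3_shared_moe.py` with an `auto_map` entry pointing at `Qwen3SharedMoeForCausalLM`; the repository id below is a placeholder, not a real model.

```python
# Hypothetical loading / forward example for the CCE-patched model above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/qwen3-shared-moe"  # placeholder repo id (assumption)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # imports the hub modeling file, which runs the Liger/CCE patches at import time
)

batch = tokenizer("Hello world", return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])
# Because _PATCH_OPTS sets train_only=False, the apply_lce branch is expected to
# handle the loss whenever labels are provided; logits stay None on that path.
print(out.loss, out.logits)
```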
modeling_qwen3_shared_moe_monkeypatch_liger_flce.py ADDED
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2025 Charles O. Goddard, The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The following monkeypatches were applied by Doctor Shotgun:
#
# Liger Kernel (https://github.com/linkedin/Liger-Kernel):
# 1. Liger RMSNorm
# 2. Liger RoPE
# 3. Liger SwiGLUMLP
# 4. Liger Fused Linear Cross-Entropy
"""PyTorch Qwen3 model with shared expert support."""

from typing import List, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F

# Liger Patch #
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

import transformers.models.qwen3_moe.modeling_qwen3_moe
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
# Liger Patch #

from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.models.mixtral.modeling_mixtral import (
    load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeMLP,
    Qwen3MoeRMSNorm,
    Qwen3MoeAttention,
    Qwen3MoeDecoderLayer,
    Qwen3MoeModel,
    Qwen3MoeForCausalLM,
)
from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig

import scattermoe


logger = logging.get_logger(__name__)


class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        if config.shared_expert_intermediate_size is not None:
            self.shared_expert = Qwen3MoeMLP(
                config, intermediate_size=config.shared_expert_intermediate_size
            )
        else:
            self.shared_expert = None
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=self.config.hidden_size,
            hidden_size=self.config.moe_intermediate_size,
            num_experts=self.config.num_experts,
            top_k=self.config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # handling of gate/router logits copied from Qwen3MoeSparseMoeBlock
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.config.num_experts_per_tok, dim=-1
        )
        if self.config.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # modified here to use scattermoe + shared_expert
        hs_0 = self.moe_mlp(hidden_states, routing_weights, selected_experts)

        if self.shared_expert is not None:
            shared_res = self.shared_expert(hidden_states)
            res = hs_0 + shared_res
        else:
            res = hs_0
        res = res.reshape(batch_size, sequence_length, hidden_dim)
        return res, router_logits


class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(config, layer_idx)

        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class Qwen3SharedMoeModel(Qwen3MoeModel):
    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )


class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts

    # Liger Patch #
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        skip_logits: Optional[bool] = None,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

        >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        kept_hidden_states = hidden_states[:, slice_indices, :]

        shift_labels = kwargs.pop("shift_labels", None)
        logits = None
        loss = None

        if skip_logits is None:
            skip_logits = self.training and (labels is not None or shift_labels is not None)

        if skip_logits:
            loss = LigerForCausalLMLoss(
                hidden_states=kept_hidden_states,
                lm_head_weight=self.lm_head.weight,
                labels=labels,
                shift_labels=shift_labels,
                hidden_size=self.config.hidden_size,
                **kwargs,
            )
        else:  # if in inference mode, materialize logits
            logits = self.lm_head(kept_hidden_states)
            if labels is not None:
                loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
    # Liger Patch #
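The FLCE file differs from the CCE file only in the loss path: when `skip_logits` resolves to True (training with `labels` or `shift_labels`), `LigerForCausalLMLoss` computes the loss directly from the hidden states and `lm_head` logits are never materialized; otherwise logits are produced as usual. A hedged sketch of both paths, assuming `model` and `tokenizer` were loaded as in the earlier placeholder example but from a checkpoint whose `auto_map` points at this FLCE variant:

```python
# Training step: fused linear cross-entropy, no logits materialized.
model.train()
batch = tokenizer("Hello world", return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()
assert out.logits is None  # skip_logits resolved to True (training mode + labels)

# Evaluation: no labels, so skip_logits resolves to False and logits are materialized.
model.eval()
with torch.no_grad():
    out = model(**batch)
    next_token_id = out.logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id))
```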