# coding=utf-8
# Copyright 2025 Charles O. Goddard, The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The following monkeypatches were applied by Doctor Shotgun:
#
# Liger Kernel (https://github.com/linkedin/Liger-Kernel):
# 1. Liger RMSNorm
# 2. Liger RoPE
# 3. Liger SwiGLUMLP
#
# Cut Cross-Entropy (https://github.com/apple/ml-cross-entropy):
# 1. Cut Cross-Entropy
"""PyTorch Qwen3 model with shared expert support."""

from typing import List, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F

# CCE Patch #
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    apply_lce,
)

# Options handed to `apply_lce` whenever the fused linear + cross-entropy
# path is taken in `Qwen3SharedMoeForCausalLM.forward`.
_PATCH_OPTS = PatchOptions(
    impl=LCE_IMPL_DEFAULT,
    reduction="mean",
    filter_eps="auto",
    accum_e_fp32=False,
    accum_c_fp32=False,
    filter_e_grad=True,
    filter_c_grad=True,
    train_only=False,
)
# CCE Patch #

# Liger Patch #
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers.rope import liger_rotary_pos_emb
import transformers.models.qwen3_moe.modeling_qwen3_moe

# These assignments must stay ABOVE the `from ...modeling_qwen3_moe import`
# statement below: that import binds whatever objects the module attributes
# point to at import time, so the names used in this file resolve to the
# Liger replacements.
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
# Liger Patch #

from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.models.mixtral.modeling_mixtral import (
    load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeMLP,
    Qwen3MoeRMSNorm,
    Qwen3MoeAttention,
    Qwen3MoeDecoderLayer,
    Qwen3MoeModel,
    Qwen3MoeForCausalLM,
)
from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig

import scattermoe


logger = logging.get_logger(__name__)


class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    """Sparse MoE block backed by scattermoe, with an optional shared expert.

    A linear router scores every token against ``config.num_experts`` experts;
    the top ``config.num_experts_per_tok`` experts are evaluated through a
    ``scattermoe.mlp.GLUMLP``. When ``config.shared_expert_intermediate_size``
    is not ``None`` an additional dense MLP is applied to every token and its
    output is added to the routed output.
    """

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config
        # Router: one logit per expert for each token.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        if config.shared_expert_intermediate_size is not None:
            self.shared_expert = Qwen3MoeMLP(
                config, intermediate_size=config.shared_expert_intermediate_size
            )
        else:
            self.shared_expert = None
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=self.config.hidden_size,
            hidden_size=self.config.moe_intermediate_size,
            num_experts=self.config.num_experts,
            top_k=self.config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Route tokens through the experts and add the shared expert output.

        Args:
            hidden_states: Tensor unpacked as
                ``(batch_size, sequence_length, hidden_dim)``.

        Returns:
            A ``(output, router_logits)`` pair. ``output`` is reshaped back to
            ``(batch_size, sequence_length, hidden_dim)``; ``router_logits``
            are the raw gate outputs for the flattened tokens.
        """
        # handling of gate/router logits copied from Qwen3MoeSparseMoeBlock
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax is computed in float32, then the top-k experts are selected.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.config.num_experts_per_tok, dim=-1
        )
        if self.config.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # modified here to use scattermoe + shared_expert
        hs_0 = self.moe_mlp(hidden_states, routing_weights, selected_experts)

        if self.shared_expert is not None:
            shared_res = self.shared_expert(hidden_states)
            res = hs_0 + shared_res
        else:
            res = hs_0

        res = res.reshape(batch_size, sequence_length, hidden_dim)
        return res, router_logits


class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
    """Qwen3 MoE decoder layer whose sparse MLP is `Qwen3SharedMoeSparseMoeBlock`.

    After the parent constructor runs, the attention, MLP and both layer norms
    are rebuilt here; only the choice of the sparse block differs in this code.
    """

    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(config, layer_idx)

        # A layer is sparse when it is not listed in `mlp_only_layers`, experts
        # are configured, and it falls on the `decoder_sparse_step` stride.
        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Qwen3MoeRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class Qwen3SharedMoeModel(Qwen3MoeModel):
    """`Qwen3MoeModel` whose decoder stack is built from `Qwen3SharedMoeDecoderLayer`."""

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # The layers above were created after the parent constructor finished,
        # so run the model finalisation again to cover them.
        # NOTE(review): relies on the transformers convention that `post_init`
        # is safe to call more than once - confirm for the pinned version.
        self.post_init()


class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal LM head on top of `Qwen3SharedMoeModel`, with a Cut Cross-Entropy loss path."""

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts
        # `self.model` was swapped after the parent constructor finished, so
        # finalise again to initialise/tie against the replacement backbone.
        # NOTE(review): same assumption as in Qwen3SharedMoeModel.__init__.
        self.post_init()

    # CCE Patch #
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        """Run the backbone and compute either the loss (fused) or the logits.

        When ``_PATCH_OPTS.use_lce(labels, self.training)`` is true the loss is
        produced by ``apply_lce`` directly from the hidden states and the
        ``lm_head`` weight, and ``logits`` in the returned output is ``None``.
        Otherwise logits are computed through ``lm_head`` and, if ``labels`` is
        given, the loss comes from ``self.loss_function``.

        If ``output_router_logits`` is enabled, the load-balancing auxiliary
        loss is computed and, when ``labels`` is given, added to the loss
        scaled by ``self.router_aux_loss_coef``.

        Args:
            logits_to_keep: An ``int`` selects the last ``logits_to_keep``
                positions (``0`` keeps all); any other value is used directly
                as the index along the sequence dimension.
            **kwargs: Forwarded to the backbone and to the loss computation.

        Returns:
            A ``MoeCausalLMOutputWithPast`` with loss, aux loss, logits and the
            cache / hidden states / attentions / router logits of the backbone.

        Raises:
            ValueError: If the backbone returns no ``last_hidden_state``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_router_logits = (
            output_router_logits
            if output_router_logits is not None
            else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        if hidden_states is None:
            raise ValueError("hidden_states is None")

        loss = None
        logits = None

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        kept_hidden_states = hidden_states[:, slice_indices, :]

        if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
            # Fused path: the full logits tensor is never materialised here.
            assert labels is not None
            loss = apply_lce(
                kept_hidden_states,
                self.lm_head.weight,
                labels,
                _PATCH_OPTS,
                **kwargs,
            )
        else:
            logits = self.lm_head(kept_hidden_states)

            if labels is not None:
                loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(
                    loss.device
                )  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

    # CCE Patch #