from typing import Optional

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from transformers import PretrainedConfig, PreTrainedModel

from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel


class NextDiTCrossAttnConfig(PretrainedConfig):
    """HF-style config for the NextDiT cross-attention diffusion head."""

    model_type = "nextdit-crossattn"

    def __init__(
        self,
        input_size: int = 8,
        patch_size: int = 1,
        in_channels: int = 1792,
        dim: int = 1792,
        n_layers: int = 24,
        n_heads: int = 28,
        n_kv_heads: int = 28,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        latent_embedding_size: int = 3584,
        learn_sigma: bool = False,
        qk_norm: bool = True,
        _gradient_checkpointing: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps
        self.learn_sigma = learn_sigma
        self.qk_norm = qk_norm
        self.latent_embedding_size = latent_embedding_size
        self._gradient_checkpointing = _gradient_checkpointing


class NextDiTCrossAttn(PreTrainedModel):
    """Wraps LuminaNextDiT2DModel and conditions it on latent embeddings via cross-attention."""

    config_class = NextDiTCrossAttnConfig

    def __init__(
        self,
        config: NextDiTCrossAttnConfig,
    ) -> None:
        super().__init__(config)
        assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn"

        self._gradient_checkpointing = config._gradient_checkpointing
        self.model = LuminaNextDiT2DModel(
            sample_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            hidden_size=config.dim,
            num_layers=config.n_layers,
            num_attention_heads=config.n_heads,
            num_kv_heads=config.n_kv_heads,
            multiple_of=config.multiple_of,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            norm_eps=config.norm_eps,
            learn_sigma=config.learn_sigma,
            qk_norm=config.qk_norm,
            cross_attention_dim=config.latent_embedding_size,
        )
        if self._gradient_checkpointing:
            self.model.enable_gradient_checkpointing()

        # self.model.requires_grad_(False)

        # Precompute 2D rotary position embeddings (per-head dim, 384x384 grid),
        # reused across forward calls.
        self.freqs_cis = get_2d_rotary_pos_embed_lumina(
            config.dim // config.n_heads,
            384,
            384,
        )

    def forward(self, x, timestep, z_latents, **kwargs):
        # x: noisy input latents; timestep: diffusion timestep per sample;
        # z_latents: conditioning embeddings of shape (B, N, latent_embedding_size),
        # attended to through cross-attention with an all-ones encoder mask.
        model_pred = self.model(
            hidden_states=x,
            timestep=timestep,
            encoder_hidden_states=z_latents,
            encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device),
            image_rotary_emb=self.freqs_cis,
            cross_attention_kwargs=dict(),
        ).sample
        return model_pred
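

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Shows how the wrapper might be called with the default config: x is a batch of
# 8x8 latent grids with 1792 channels, and z_latents is a sequence of conditioning
# tokens of width latent_embedding_size (3584). The batch size, token count, and
# random inputs below are assumptions for a smoke test, not values from the source.
if __name__ == "__main__":
    config = NextDiTCrossAttnConfig(_gradient_checkpointing=False)
    model = NextDiTCrossAttn(config).eval()

    batch_size, num_cond_tokens = 2, 64  # hypothetical shapes for the sketch
    x = torch.randn(batch_size, config.in_channels, config.input_size, config.input_size)
    timestep = torch.rand(batch_size)  # assumed continuous timesteps in [0, 1]
    z_latents = torch.randn(batch_size, num_cond_tokens, config.latent_embedding_size)

    with torch.no_grad():
        pred = model(x, timestep, z_latents)
    print(pred.shape)  # expected to match the spatial/channel shape of x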