from typing import Optional

from transformers import AutoConfig, Gemma3TextConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.siglip import SiglipVisionConfig

logger = logging.get_logger(__name__)


class AudioConfig(PretrainedConfig):
    """Configuration for the Gemma3 audio encoder.

    Stores the hyperparameters of a conformer-style speech encoder
    (attention + depthwise-separable convolution blocks) fed by a
    "nemo_conv" subsampling front end. All arguments are persisted as
    attributes of the same name; the four ``dict``-valued arguments get
    per-instance default dicts when ``None`` is passed (never shared
    mutable defaults).

    Args:
        input_size: Number of input feature bins (e.g. mel filterbanks).
        attention_dim: Hidden size of the encoder attention layers.
        attention_heads: Number of attention heads per block.
        num_blocks: Number of encoder blocks.
        linear_units: Inner size of the feed-forward layers.
        dropout_rate: Dropout probability applied inside the encoder.
        kernel_size: Depthwise convolution kernel size.
        ext_pw_kernel_size: Kernel size of the external pointwise conv.
        ext_pw_out_channel: Output channels of the external pointwise conv.
        depthwise_seperable_out_channel: Output channels of the
            depthwise-separable conv (name kept as-is for checkpoint
            compatibility, despite the spelling).
        depthwise_multiplier: Channel multiplier of the depthwise conv.
        activation: Activation for the feed-forward layers.
        conv_activation: Activation inside the convolution module.
        conv_glu_type: Activation used in the GLU gate.
        bias_in_glu: Whether the GLU projections carry a bias.
        causal: Whether convolutions are causal (no future context).
        batch_norm: Use batch norm in the conv module.
        cnn_layer_norm: Use layer norm in the conv module.
        time_reduction: Total time subsampling factor of the front end.
        input_layer: Type of input subsampling layer.
        nemo_conv_settings: Options for the "nemo_conv" front end;
            defaults to ``{"conv_channels": 1024}``.
        chunk_size: Streaming chunk size (-1 disables chunking).
        left_chunk: Number of left context chunks for streaming attention.
        relative_attention_bias_args: Relative position bias options;
            defaults to ``{"type": "t5", "t5_bias_max_distance": 500}``.
        activation_checkpointing: Checkpointing options; defaults to
            ``{"interval": 1, "module": "transformer", "offload": False}``.
        encoder_embedding_config: Options for the embedding front end;
            defaults to ``{"input_size": input_size}``.
    """

    model_type = "gemma3_audio"

    def __init__(
        self,
        input_size=80,
        attention_dim=1024,
        attention_heads=16,
        num_blocks=24,
        linear_units=1536,
        dropout_rate=0.0,
        kernel_size=3,
        ext_pw_kernel_size=1,
        ext_pw_out_channel=1024,
        depthwise_seperable_out_channel=1024,
        depthwise_multiplier=1,
        activation="swish",
        conv_activation="swish",
        conv_glu_type="swish",
        bias_in_glu=True,
        causal=True,
        batch_norm=False,
        cnn_layer_norm=True,
        time_reduction=8,
        input_layer="nemo_conv",
        nemo_conv_settings=None,
        chunk_size=-1,
        left_chunk=18,
        relative_attention_bias_args=None,
        activation_checkpointing=None,
        encoder_embedding_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.attention_dim = attention_dim
        self.attention_heads = attention_heads
        self.num_blocks = num_blocks
        self.linear_units = linear_units
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.ext_pw_out_channel = ext_pw_out_channel
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.depthwise_multiplier = depthwise_multiplier
        self.activation = activation
        self.conv_activation = conv_activation
        self.conv_glu_type = conv_glu_type
        self.bias_in_glu = bias_in_glu
        self.causal = causal
        self.batch_norm = batch_norm
        self.cnn_layer_norm = cnn_layer_norm
        self.time_reduction = time_reduction
        self.input_layer = input_layer
        # Build fresh default dicts per instance so no two configs ever
        # share (and accidentally co-mutate) the same dict object.
        if nemo_conv_settings is None:
            self.nemo_conv_settings = {"conv_channels": 1024}
        else:
            self.nemo_conv_settings = nemo_conv_settings
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        if relative_attention_bias_args is None:
            self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
        else:
            self.relative_attention_bias_args = relative_attention_bias_args
        if activation_checkpointing is None:
            self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
        else:
            self.activation_checkpointing = activation_checkpointing
        if encoder_embedding_config is None:
            self.encoder_embedding_config = {"input_size": input_size}
        else:
            self.encoder_embedding_config = encoder_embedding_config


class Gemma3MMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
    Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the PaliGemma-2B.

    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object of the text backbone.
        vision_config (`Union[AutoConfig, dict]`,  *optional*):
            Custom vision config or dict.
        audio_config (`Union[AutoConfig, dict]`,  *optional*):
            Custom audio config or dict.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        boa_token_index (`int`, *optional*, defaults to 256001):
            The begin-of-audio token index to wrap the audio prompt.
        eoa_token_index (`int`, *optional*, defaults to 256002):
            The end-of-audio token index to wrap the audio prompt.
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        audio_token_index (`int`, *optional*, defaults to 262143):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing an audio config
    >>> audio_config = AudioConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(vision_config, text_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma3mm"
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
        "audio_config": AudioConfig,
    }

    def __init__(
        self,
        text_config: Optional[Gemma3TextConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        boa_token_index: int = 256_001,
        eoa_token_index: int = 256_002,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_143,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif not isinstance(vision_config, SiglipVisionConfig):
            # None (or an incompatible value): fall back to defaults. An
            # already-constructed SiglipVisionConfig instance is kept as-is
            # rather than being silently replaced.
            vision_config = SiglipVisionConfig()
            logger.info(
                "vision_config is None or incompatible with SiglipVisionConfig initialization. Gemma3 will be "
                "limited to text tasks."
            )

        if isinstance(audio_config, dict):
            audio_config = AudioConfig(**audio_config)
        elif not isinstance(audio_config, AudioConfig):
            # Same fallback policy as vision_config above.
            audio_config = AudioConfig()
            logger.info(
                "audio_config is None or incompatible with AudioConfig initialization. Gemma3 will be "
                "limited to text tasks."
            )

        self.text_config = text_config
        self.vision_config = vision_config
        self.audio_config = audio_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.boa_token_index = boa_token_index
        self.eoa_token_index = eoa_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        super().__init__(**kwargs)