# File: gemma-3-4b-it-speech / configuration_gemma3mm.py
# Repository page metadata: junnei · commit aef5652 (verified) "fix config errors" · 8.01 kB
from typing import Optional
from transformers import AutoConfig, Gemma3TextConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.siglip import SiglipVisionConfig
# Module-level logger, obtained from the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
class AudioConfig(PretrainedConfig):
    """Configuration container for the audio sub-model (``model_type = "gemma3_audio"``).

    Every constructor argument is stored on the instance under the same name.
    Four dict-valued options fall back to a built-in default when ``None`` is
    passed:

    * ``nemo_conv_settings`` -> ``{"conv_channels": 1024}``
    * ``relative_attention_bias_args`` -> ``{"type": "t5", "t5_bias_max_distance": 500}``
    * ``activation_checkpointing`` -> ``{"interval": 1, "module": "transformer", "offload": False}``
    * ``encoder_embedding_config`` -> ``{"input_size": input_size}``

    Any additional keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "gemma3_audio"

    def __init__(
        self,
        input_size=80,
        attention_dim=1024,
        attention_heads=16,
        num_blocks=24,
        linear_units=1536,
        dropout_rate=0.0,
        kernel_size=3,
        ext_pw_kernel_size=1,
        ext_pw_out_channel=1024,
        depthwise_seperable_out_channel=1024,
        depthwise_multiplier=1,
        activation="swish",
        conv_activation="swish",
        conv_glu_type="swish",
        bias_in_glu=True,
        causal=True,
        batch_norm=False,
        cnn_layer_norm=True,
        time_reduction=8,
        input_layer="nemo_conv",
        nemo_conv_settings=None,
        chunk_size=-1,
        left_chunk=18,
        relative_attention_bias_args=None,
        activation_checkpointing=None,
        encoder_embedding_config=None,
        **kwargs
    ):
        # Let the base class consume the generic config kwargs first.
        super().__init__(**kwargs)

        # Scalar options are stored verbatim.
        self.input_size = input_size
        self.attention_dim = attention_dim
        self.attention_heads = attention_heads
        self.num_blocks = num_blocks
        self.linear_units = linear_units
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.ext_pw_out_channel = ext_pw_out_channel
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.depthwise_multiplier = depthwise_multiplier
        self.activation = activation
        self.conv_activation = conv_activation
        self.conv_glu_type = conv_glu_type
        self.bias_in_glu = bias_in_glu
        self.causal = causal
        self.batch_norm = batch_norm
        self.cnn_layer_norm = cnn_layer_norm
        self.time_reduction = time_reduction
        self.input_layer = input_layer

        # Dict-valued options: only an explicit ``None`` triggers the default
        # (an empty dict supplied by the caller is kept as-is). A fresh dict is
        # built on every call so instances never share a default object.
        self.nemo_conv_settings = (
            nemo_conv_settings if nemo_conv_settings is not None else {"conv_channels": 1024}
        )

        self.chunk_size = chunk_size
        self.left_chunk = left_chunk

        self.relative_attention_bias_args = (
            relative_attention_bias_args
            if relative_attention_bias_args is not None
            else {"type": "t5", "t5_bias_max_distance": 500}
        )
        self.activation_checkpointing = (
            activation_checkpointing
            if activation_checkpointing is not None
            else {"interval": 1, "module": "transformer", "offload": False}
        )
        # The embedding default mirrors whatever ``input_size`` was passed in.
        self.encoder_embedding_config = (
            encoder_embedding_config
            if encoder_embedding_config is not None
            else {"input_size": input_size}
        )
class Gemma3MMConfig(PretrainedConfig):
    r"""
    Configuration class for a Gemma3 multimodal model combining a text backbone, a vision tower and an audio
    encoder. It bundles one sub-configuration per modality together with the special token indices used to mark
    image and audio content in the prompt.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object (or a dict of its arguments) of the text backbone. When `None`, a default
            [`Gemma3TextConfig`] is created.
        vision_config (`Union[SiglipVisionConfig, dict]`, *optional*):
            The config object (or a dict of its arguments) of the vision tower. When `None` (or not a config
            object), a default [`SiglipVisionConfig`] is created.
        audio_config (`Union[AudioConfig, dict]`, *optional*):
            The config object (or a dict of its arguments) of the audio encoder. When `None` (or not a config
            object), a default `AudioConfig` is created.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        boa_token_index (`int`, *optional*, defaults to 256001):
            The begin-of-audio token index to wrap the audio prompt.
        eoa_token_index (`int`, *optional*, defaults to 256002):
            The end-of-audio token index to wrap the audio prompt.
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        audio_token_index (`int`, *optional*, defaults to 262143):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```python
    >>> from transformers import Gemma3TextConfig, SiglipVisionConfig

    >>> # Initializing a Gemma3 text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing an audio config
    >>> audio_config = AudioConfig()

    >>> # Combining them into a multimodal configuration
    >>> configuration = Gemma3MMConfig(text_config, vision_config, audio_config)

    >>> # Accessing a sub-configuration
    >>> audio_settings = configuration.audio_config
    ```"""

    model_type = "gemma3mm"
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
        "audio_config": AudioConfig,
    }

    def __init__(
        self,
        text_config: Optional[Gemma3TextConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        boa_token_index: int = 256_001,
        eoa_token_index: int = 256_002,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_143,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Text backbone: build from a dict, default when missing, otherwise keep the object as given.
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        # Vision tower: a dict is expanded into a config; an already-built config object is kept
        # (previously it was silently replaced by a default); anything else falls back to the default.
        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif not isinstance(vision_config, PretrainedConfig):
            vision_config = SiglipVisionConfig()
            logger.info("vision_config is None or not a config object, using default SiglipVisionConfig.")

        # Audio encoder: same resolution rules as the vision tower.
        if isinstance(audio_config, dict):
            audio_config = AudioConfig(**audio_config)
        elif not isinstance(audio_config, PretrainedConfig):
            audio_config = AudioConfig()
            logger.info("audio_config is None or not a config object, using default AudioConfig.")

        self.text_config = text_config
        self.vision_config = vision_config
        self.audio_config = audio_config

        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.boa_token_index = boa_token_index
        self.eoa_token_index = eoa_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        # Remaining keyword arguments are handled by the base class.
        super().__init__(**kwargs)