File size: 8,009 Bytes
43e08ee eb143e4 43e08ee aef5652 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
from typing import Optional
from transformers import AutoConfig, Gemma3TextConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.siglip import SiglipVisionConfig
# Module-level logger, obtained through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
class AudioConfig(PretrainedConfig):
    """Configuration holder for the audio component (``model_type = "gemma3_audio"``).

    Every constructor argument is stored unchanged as an instance attribute of the
    same name. Four dict-valued options are replaced by a freshly built default
    when they are passed as ``None``:

    * ``nemo_conv_settings`` -> ``{"conv_channels": 1024}``
    * ``relative_attention_bias_args`` -> ``{"type": "t5", "t5_bias_max_distance": 500}``
    * ``activation_checkpointing`` ->
      ``{"interval": 1, "module": "transformer", "offload": False}``
    * ``encoder_embedding_config`` -> ``{"input_size": input_size}``

    Any additional keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of the individual hyper-parameters is defined by the
    audio encoder that consumes this config, which is not part of this file --
    confirm against that implementation before changing defaults.
    """

    model_type = "gemma3_audio"

    def __init__(
        self,
        input_size=80,
        attention_dim=1024,
        attention_heads=16,
        num_blocks=24,
        linear_units=1536,
        dropout_rate=0.0,
        kernel_size=3,
        ext_pw_kernel_size=1,
        ext_pw_out_channel=1024,
        depthwise_seperable_out_channel=1024,
        depthwise_multiplier=1,
        activation="swish",
        conv_activation="swish",
        conv_glu_type="swish",
        bias_in_glu=True,
        causal=True,
        batch_norm=False,
        cnn_layer_norm=True,
        time_reduction=8,
        input_layer="nemo_conv",
        nemo_conv_settings=None,
        chunk_size=-1,
        left_chunk=18,
        relative_attention_bias_args=None,
        activation_checkpointing=None,
        encoder_embedding_config=None,
        **kwargs
    ):
        # Let the base class consume the generic config kwargs first.
        super().__init__(**kwargs)

        # Scalar hyper-parameters are stored verbatim.
        self.input_size = input_size
        self.attention_dim = attention_dim
        self.attention_heads = attention_heads
        self.num_blocks = num_blocks
        self.linear_units = linear_units
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.ext_pw_out_channel = ext_pw_out_channel
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.depthwise_multiplier = depthwise_multiplier
        self.activation = activation
        self.conv_activation = conv_activation
        self.conv_glu_type = conv_glu_type
        self.bias_in_glu = bias_in_glu
        self.causal = causal
        self.batch_norm = batch_norm
        self.cnn_layer_norm = cnn_layer_norm
        self.time_reduction = time_reduction
        self.input_layer = input_layer

        # Dict-valued options use ``None`` as the "not provided" sentinel so that
        # each instance gets its own default dict rather than a shared one.
        self.nemo_conv_settings = (
            {"conv_channels": 1024} if nemo_conv_settings is None else nemo_conv_settings
        )
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.relative_attention_bias_args = (
            {"type": "t5", "t5_bias_max_distance": 500}
            if relative_attention_bias_args is None
            else relative_attention_bias_args
        )
        self.activation_checkpointing = (
            {"interval": 1, "module": "transformer", "offload": False}
            if activation_checkpointing is None
            else activation_checkpointing
        )
        # The embedding default mirrors whatever ``input_size`` was requested.
        self.encoder_embedding_config = (
            {"input_size": input_size}
            if encoder_embedding_config is None
            else encoder_embedding_config
        )
class Gemma3MMConfig(PretrainedConfig):
    r"""
    Composite configuration (``model_type = "gemma3mm"``) that bundles a text, a vision and an audio
    sub-configuration together with the special-token indices used to splice image and audio inputs
    into the text stream.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object of the text backbone. A `dict` is expanded into a [`Gemma3TextConfig`];
            `None` yields a default [`Gemma3TextConfig`]; any other value is stored unchanged.
        vision_config (`Union[SiglipVisionConfig, dict]`, *optional*):
            Vision config or dict. A `dict` is expanded into a [`SiglipVisionConfig`]; a
            [`SiglipVisionConfig`] instance is kept as-is; `None` or any other value is replaced by a
            default [`SiglipVisionConfig`].
        audio_config (`Union[AudioConfig, dict]`, *optional*):
            Audio config or dict. A `dict` is expanded into an [`AudioConfig`]; an [`AudioConfig`]
            instance is kept as-is; `None` or any other value is replaced by a default [`AudioConfig`].
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        boa_token_index (`int`, *optional*, defaults to 256001):
            The begin-of-audio token index (presumably wraps the audio prompt, mirroring
            `boi_token_index` -- confirm against the processor/model code).
        eoa_token_index (`int`, *optional*, defaults to 256002):
            The end-of-audio token index (presumably wraps the audio prompt, mirroring
            `eoi_token_index` -- confirm against the processor/model code).
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        audio_token_index (`int`, *optional*, defaults to 262143):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.

    Example:

    ```python
    >>> # Initializing a Siglip vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing an audio config
    >>> audio_config = AudioConfig()

    >>> # Initializing a Gemma3 text config
    >>> text_config = Gemma3TextConfig()

    >>> # Combining them into a multimodal configuration
    >>> configuration = Gemma3MMConfig(text_config, vision_config, audio_config)

    >>> # Accessing the sub-configurations
    >>> audio_config = configuration.audio_config
    ```"""

    model_type = "gemma3mm"
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
        "audio_config": AudioConfig,
    }

    def __init__(
        self,
        text_config: Optional[Gemma3TextConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        boa_token_index: int = 256_001,
        eoa_token_index: int = 256_002,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_143,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # --- text backbone -------------------------------------------------
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        # --- vision tower --------------------------------------------------
        # A ready-made SiglipVisionConfig must be kept; previously every non-dict
        # value (including valid config instances) was overwritten by the defaults.
        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif not isinstance(vision_config, SiglipVisionConfig):
            vision_config = SiglipVisionConfig()
            logger.info(
                "vision_config is None or incompatible with SiglipVisionConfig initialization. "
                "Using the default SiglipVisionConfig."
            )

        # --- audio encoder -------------------------------------------------
        # Same rule as for the vision tower: keep AudioConfig instances as given.
        if isinstance(audio_config, dict):
            audio_config = AudioConfig(**audio_config)
        elif not isinstance(audio_config, AudioConfig):
            audio_config = AudioConfig()
            logger.info(
                "audio_config is None or incompatible with AudioConfig initialization. "
                "Using the default AudioConfig."
            )

        self.text_config = text_config
        self.vision_config = vision_config
        self.audio_config = audio_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.boa_token_index = boa_token_index
        self.eoa_token_index = eoa_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        # The base initializer runs last, after the sub-configs and token
        # indices are set, exactly as in the original ordering.
        super().__init__(**kwargs)