junnei committed
Commit 43e08ee · verified · 1 Parent(s): c8dd61f

test auto modeling files

configuration_gemma3mm.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Optional
2
+
3
+ from transformers import Gemma3TextConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.modeling_rope_utils import rope_config_validation
6
+ from transformers.utils import logging
7
+ from transformers.models.siglip import SiglipVisionConfig
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ class AudioConfig(PretrainedConfig):
13
+ model_type = "gemma3_audio"
14
+
15
+ def __init__(
16
+ self,
17
+ input_size=80,
18
+ attention_dim=1024,
19
+ attention_heads=16,
20
+ num_blocks=24,
21
+ linear_units=1536,
22
+ dropout_rate=0.0,
23
+ kernel_size=3,
24
+ ext_pw_kernel_size=1,
25
+ ext_pw_out_channel=1024,
26
+ depthwise_seperable_out_channel=1024,
27
+ depthwise_multiplier=1,
28
+ activation="swish",
29
+ conv_activation="swish",
30
+ conv_glu_type="swish",
31
+ bias_in_glu=True,
32
+ causal=True,
33
+ batch_norm=False,
34
+ cnn_layer_norm=True,
35
+ time_reduction=8,
36
+ input_layer="nemo_conv",
37
+ nemo_conv_settings=None,
38
+ chunk_size=-1,
39
+ left_chunk=18,
40
+ relative_attention_bias_args=None,
41
+ activation_checkpointing=None,
42
+ encoder_embedding_config=None,
43
+ **kwargs
44
+ ):
45
+ super().__init__(**kwargs)
46
+
47
+ self.input_size = input_size
48
+ self.attention_dim = attention_dim
49
+ self.attention_heads = attention_heads
50
+ self.num_blocks = num_blocks
51
+ self.linear_units = linear_units
52
+ self.dropout_rate = dropout_rate
53
+ self.kernel_size = kernel_size
54
+ self.ext_pw_kernel_size = ext_pw_kernel_size
55
+ self.ext_pw_out_channel = ext_pw_out_channel
56
+ self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
57
+ self.depthwise_multiplier = depthwise_multiplier
58
+ self.activation = activation
59
+ self.conv_activation = conv_activation
60
+ self.conv_glu_type = conv_glu_type
61
+ self.bias_in_glu = bias_in_glu
62
+ self.causal = causal
63
+ self.batch_norm = batch_norm
64
+ self.cnn_layer_norm = cnn_layer_norm
65
+ self.time_reduction = time_reduction
66
+ self.input_layer = input_layer
67
+
68
+ if nemo_conv_settings is None:
69
+ self.nemo_conv_settings = {"conv_channels": 1024}
70
+ else:
71
+ self.nemo_conv_settings = nemo_conv_settings
72
+
73
+ self.chunk_size = chunk_size
74
+ self.left_chunk = left_chunk
75
+
76
+ if relative_attention_bias_args is None:
77
+ self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
78
+ else:
79
+ self.relative_attention_bias_args = relative_attention_bias_args
80
+
81
+ if activation_checkpointing is None:
82
+ self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
83
+ else:
84
+ self.activation_checkpointing = activation_checkpointing
85
+
86
+ if encoder_embedding_config is None:
87
+ self.encoder_embedding_config = {"input_size": input_size}
88
+ else:
89
+ self.encoder_embedding_config = encoder_embedding_config
90
+
91
+
92
+ class Gemma3MMConfig(PretrainedConfig):
93
+ r"""
94
+ This is the configuration class to store the configuration of a [`Gemma3MMForConditionalGeneration`]. It is used to instantiate a
95
+ Gemma3MMForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
96
+ with the defaults will yield a similar configuration to that of the Gemma3-4B.
97
+
98
+ e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
99
+
100
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
101
+ documentation from [`PretrainedConfig`] for more information.
102
+
103
+ Args:
104
+ text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
105
+ The config object of the text backbone.
106
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
107
+ Custom vision config or dict.
108
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
109
+ Custom audio config or dict.
110
+ mm_tokens_per_image (`int`, *optional*, defaults to 256):
111
+ The number of tokens per image embedding.
112
+ boi_token_index (`int`, *optional*, defaults to 255999):
113
+ The begin-of-image token index to wrap the image prompt.
114
+ eoi_token_index (`int`, *optional*, defaults to 256000):
115
+ The end-of-image token index to wrap the image prompt.
116
+ image_token_index (`int`, *optional*, defaults to 262144):
117
+ The image token index to encode the image prompt.
118
+ audio_token_index (`int`, *optional*, defaults to 262143):
119
+ The audio token index to encode the audio prompt.
+ boa_token_index (`int`, *optional*, defaults to 256001):
+ The begin-of-audio token index to wrap the audio prompt.
+ eoa_token_index (`int`, *optional*, defaults to 256002):
+ The end-of-audio token index to wrap the audio prompt.
120
+ initializer_range (`float`, *optional*, defaults to 0.02):
121
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
122
+
123
+
124
+ Example:
125
+
126
+ ```python
127
+ >>> from transformers import SiglipVisionConfig, Gemma3TextConfig
128
+
129
+ >>> # Initializing a Siglip-like vision config
130
+ >>> vision_config = SiglipVisionConfig()
131
+
132
+ >>> # Initializing an audio encoder config
133
+ >>> audio_config = AudioConfig()
134
+
135
+ >>> # Initializing a Gemma3 Text config
136
+ >>> text_config = Gemma3TextConfig()
137
+
138
+ >>> # Initializing a Gemma3 gemma-3-4b style configuration
139
+ >>> configuration = Gemma3MMConfig(text_config=text_config, vision_config=vision_config, audio_config=audio_config)
140
+
141
+ >>> # Initializing a model from the gemma-3-4b style configuration
142
+ >>> model = Gemma3MMForConditionalGeneration(configuration)
143
+
144
+ >>> # Accessing the model configuration
145
+ >>> configuration = model.config
146
+ ```"""
147
+
148
+ model_type = "gemma3mm"
149
+ sub_configs = {
150
+ "text_config": Gemma3TextConfig,
151
+ "vision_config": SiglipVisionConfig,
152
+ "audio_config": AudioConfig,
153
+ }
154
+
155
+ def __init__(
156
+ self,
157
+ text_config: Optional[Gemma3TextConfig] = None,
158
+ vision_config: Optional[SiglipVisionConfig] = None,
159
+ audio_config: Optional[AudioConfig] = None,
160
+ mm_tokens_per_image: int = 256,
161
+ boi_token_index: int = 255_999,
162
+ eoi_token_index: int = 256_000,
163
+ boa_token_index: int = 256_001,
164
+ eoa_token_index: int = 256_002,
165
+ image_token_index: int = 262_144,
166
+ audio_token_index: int = 262_143,
167
+ initializer_range: float = 0.02,
168
+ **kwargs,
169
+ ):
170
+ if text_config is None:
171
+ text_config = Gemma3TextConfig()
172
+ logger.info("text_config is None, using default Gemma3TextConfig.")
173
+ elif isinstance(text_config, dict):
174
+ text_config = Gemma3TextConfig(**text_config)
175
+
176
+ if isinstance(vision_config, dict):
177
+ vision_config = SiglipVisionConfig(**vision_config)
178
+ elif vision_config is None:
179
+ vision_config = SiglipVisionConfig()
180
+ logger.info(
181
+ "vision_config is None or incompatible with SiglipVisionConfig initialization. Gemma3 will be limited "
182
+ "to text tasks."
183
+ )
184
+
185
+ if isinstance(audio_config, dict):
186
+ audio_config = AudioConfig(**audio_config)
187
+ elif audio_config is None:
188
+ audio_config = AudioConfig()
189
+ logger.info(
190
+ "audio_config is None or incompatible with AudioConfig initialization. Gemma3 will be limited "
191
+ "to text and vision tasks."
192
+ )
193
+
194
+ self.text_config = text_config
195
+ self.vision_config = vision_config
196
+ self.audio_config = audio_config
197
+ self.mm_tokens_per_image = mm_tokens_per_image
198
+ self.boi_token_index = boi_token_index
199
+ self.eoi_token_index = eoi_token_index
200
+ self.boa_token_index = boa_token_index
201
+ self.eoa_token_index = eoa_token_index
202
+ self.image_token_index = image_token_index
203
+ self.audio_token_index = audio_token_index
204
+ self.initializer_range = initializer_range
205
+
206
+ super().__init__(**kwargs)
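A minimal sketch of exercising the configuration above once `configuration_gemma3mm.py` is importable; the tiny sub-config sizes are arbitrary and only meant to keep the example light, and a transformers release that ships `Gemma3TextConfig` is assumed:

```python
from configuration_gemma3mm import AudioConfig, Gemma3MMConfig

# Sub-configs can be passed as dicts or config objects; __init__ normalizes both forms.
config = Gemma3MMConfig(
    text_config={"hidden_size": 64, "num_hidden_layers": 1, "num_attention_heads": 1,
                 "num_key_value_heads": 1, "intermediate_size": 128},
    vision_config={"hidden_size": 64, "num_hidden_layers": 1, "num_attention_heads": 1,
                   "intermediate_size": 128, "image_size": 224, "patch_size": 16},
    audio_config=AudioConfig(attention_dim=256, attention_heads=4, num_blocks=2),
)

print(config.model_type)                     # "gemma3mm"
print(config.audio_config.attention_dim)     # 256
config.save_pretrained("./gemma3mm-config")  # round-trips through config.json
reloaded = Gemma3MMConfig.from_pretrained("./gemma3mm-config")
```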
modeling_gemma3mm.py ADDED
@@ -0,0 +1,640 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_gemma3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ import copy
+ import warnings
23
+ from collections.abc import Callable
24
+ from dataclasses import dataclass
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, HybridCache, StaticCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
35
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
36
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import (
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ is_torchdynamo_compiling,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.utils.deprecation import deprecate_kwarg
46
+ from transformers import AutoModel, AutoModelForCausalLM
47
+
48
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
49
+
50
+ from transformers import AutoConfig, AutoModelForCausalLM
51
+
52
+ from .configuration_gemma3mm import Gemma3MMConfig
53
+ from .processing_gemma3mm import InputMode
54
+ from .speech_conformer_encoder import ConformerEncoder
55
+
56
+ logger = logging.get_logger(__name__)
57
+ _CONFIG_FOR_DOC = "Gemma3MMConfig"
58
+
59
+ @dataclass
60
+ class Gemma3MMCausalLMOutputWithPast(Gemma3CausalLMOutputWithPast): # ← parent class changed
61
+ """
62
+ Multimodal version of `Gemma3CausalLMOutputWithPast`.
63
+ Adds audio-specific hidden states.
64
+
65
+ Args:
66
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
67
+ A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
68
+ Audio hidden states produced by the audio encoder.
69
+ """
70
+ audio_hidden_states: Optional[torch.FloatTensor] = None
71
+
72
+
73
+ GEMMA3_START_DOCSTRING = r"""
74
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
79
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
80
+ and behavior.
81
+
82
+ Parameters:
83
+ config ([`Gemma3Config`]):
84
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
85
+ load the weights associated with the model, only the configuration. Check out the
86
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
87
+ """
88
+
89
+
90
+
91
+ GEMMA3_INPUTS_DOCSTRING = r"""
92
+ Args:
93
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
94
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
95
+ it.
96
+
97
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
98
+ [`PreTrainedTokenizer.__call__`] for details.
99
+
100
+ [What are input IDs?](../glossary#input-ids)
101
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
+
104
+ - 1 for tokens that are **not masked**,
105
+ - 0 for tokens that are **masked**.
106
+
107
+ [What are attention masks?](../glossary#attention-mask)
108
+
109
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
110
+ [`PreTrainedTokenizer.__call__`] for details.
111
+
112
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
113
+ `past_key_values`).
114
+
115
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
116
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
117
+ information on the default strategy.
118
+
119
+ - 1 indicates the head is **not masked**,
120
+ - 0 indicates the head is **masked**.
121
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
122
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
123
+ config.n_positions - 1]`.
124
+
125
+ [What are position IDs?](../glossary#position-ids)
126
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
127
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
128
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
129
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
130
+
131
+ Two formats are allowed:
132
+ - a [`~cache_utils.Cache`] instance, see our
133
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
134
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
135
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
136
+ cache format.
137
+
138
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
139
+ legacy cache format will be returned.
140
+
141
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
142
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
143
+ of shape `(batch_size, sequence_length)`.
144
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
145
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
146
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
147
+ model's internal embedding lookup matrix.
148
+ use_cache (`bool`, *optional*):
149
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
150
+ `past_key_values`).
151
+ output_attentions (`bool`, *optional*):
152
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
153
+ tensors for more detail.
154
+ output_hidden_states (`bool`, *optional*):
155
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
156
+ more detail.
157
+ return_dict (`bool`, *optional*):
158
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
159
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
160
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
161
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
162
+ the complete sequence length.
163
+ """
164
+
165
+ @add_start_docstrings(
166
+ """The GEMMA3 model which consists of a vision backbone and a language model.""",
167
+ GEMMA3_START_DOCSTRING,
168
+ )
169
+ class Gemma3MMForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+ config_class = Gemma3MMConfig # ensure from_pretrained builds the multimodal config rather than the base Gemma3Config
170
+ def __init__(self, config: Gemma3MMConfig):
171
+ super().__init__(config)
172
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
173
+ audio_config = config.audio_config.to_diff_dict()
174
+ for item in ['transformers_version', 'model_type', 'torch_dtype']:
175
+ if item in audio_config:
176
+ audio_config.pop(item)
177
+ self.audio_tower = ConformerEncoder(**audio_config)
178
+ self.audio_tower.post_init({})
179
+ self.audio_projector = nn.Sequential(
180
+ nn.Linear(in_features=config.audio_config.attention_dim, out_features=config.text_config.hidden_size, bias=True),
181
+ nn.GELU(approximate='none'),
182
+ nn.Linear(in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, bias=True)
183
+ ).to(dtype=self.dtype)
184
+
185
+ self.multi_modal_projector = Gemma3MultiModalProjector(config)
186
+ self.vocab_size = config.text_config.vocab_size
187
+
188
+ language_model = AutoModelForCausalLM.from_config(config=config.text_config)
189
+
190
+ if language_model._tied_weights_keys is not None:
191
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
192
+ self.language_model = language_model
193
+
194
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
195
+
196
+ # Add LoRA adapter configuration (speech)
197
+ if hasattr(config, "speech_lora") and config.speech_lora is not None:
198
+ from peft import LoraConfig, get_peft_model
199
+ import warnings
200
+
201
+ speech_lora_config = LoraConfig(
202
+ r=config.speech_lora['r'],
203
+ lora_alpha=config.speech_lora['lora_alpha'],
204
+ target_modules=config.speech_lora['layer'],
205
+ use_rslora=config.speech_lora['use_rslora'],
206
+ lora_dropout=config.speech_lora['dp'],
207
+ task_type="CAUSAL_LM",
208
+ )
209
+ self.language_model.model = get_peft_model(self.language_model.model, speech_lora_config, adapter_name="speech")
210
+
211
+ self.post_init()
212
+
213
+ def set_lora_adapter(self, adapter_name) -> None:
214
+ from peft.tuners.lora.layer import LoraLayer
215
+ for module in self.modules():
216
+ if isinstance(module, LoraLayer):
217
+ if module.merged:
218
+ warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
219
+ module.unmerge()
220
+ module.set_adapter(adapter_name)
221
+ module._disable_adapters = False
222
+
223
+ def unset_lora_adapter(self) -> None:
224
+ # Ref: peft/tuners/tuners_utils.py - enable_adapters()
225
+ # Ref: peft/tuners/lora/layer.py
226
+ from peft.tuners.lora.layer import LoraLayer
227
+ for module in self.modules():
228
+ if isinstance(module, LoraLayer):
229
+ # disable grads on all adapter layers
230
+ # TODO weijian: may use enable_adapters() instead
231
+ for layer_name in module.adapter_layer_names:
232
+ layer = getattr(module, layer_name)
233
+ layer.requires_grad_(False)
234
+ module._disable_adapters = True
235
+
236
+ def get_input_embeddings(self):
237
+ return self.language_model.get_input_embeddings()
238
+
239
+ def set_input_embeddings(self, value):
240
+ self.language_model.set_input_embeddings(value)
241
+
242
+ def get_output_embeddings(self):
243
+ return self.language_model.get_output_embeddings()
244
+
245
+ def set_output_embeddings(self, new_embeddings):
246
+ self.language_model.set_output_embeddings(new_embeddings)
247
+
248
+ def set_decoder(self, decoder):
249
+ self.language_model.set_decoder(decoder)
250
+
251
+ def get_decoder(self):
252
+ return self.language_model.get_decoder()
253
+
254
+ def _update_causal_mask(
255
+ self,
256
+ attention_mask,
257
+ token_type_ids,
258
+ past_key_values,
259
+ cache_position,
260
+ input_tensor,
261
+ is_training: bool = False,
262
+ ):
263
+ if self.config.text_config._attn_implementation == "flash_attention_2":
264
+ return attention_mask
265
+
266
+ if attention_mask is not None and attention_mask.dim() == 4:
267
+ # In this case we assume that the mask comes already in inverted
268
+ # form and requires no inversion or slicing.
269
+ return attention_mask
270
+
271
+ using_static_cache = isinstance(past_key_values, StaticCache)
272
+ min_dtype = torch.finfo(self.dtype).min
273
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
274
+ if using_static_cache:
275
+ target_length = past_key_values.get_max_cache_shape()
276
+ elif isinstance(past_key_values, HybridCache):
277
+ target_length = past_key_values.get_max_cache_shape()
278
+ else:
279
+ target_length = (
280
+ attention_mask.shape[-1]
281
+ if isinstance(attention_mask, torch.Tensor)
282
+ else cache_position[0] + sequence_length + 1
283
+ )
284
+
285
+ if attention_mask is not None and attention_mask.dim() == 4:
286
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
287
+ return attention_mask
288
+
289
+ causal_mask = torch.full(
290
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
291
+ )
292
+
293
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
294
+ if sequence_length != 1:
295
+ causal_mask = torch.triu(causal_mask, diagonal=1)
296
+
297
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
298
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
299
+
300
+ # Apply bidirectional mask on images if token type ids are provided
301
+ if token_type_ids is not None and sequence_length != 1:
302
+ token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
303
+ token_type_mask[token_type_ids == 0] = False # if text token do not change anything
304
+ token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
305
+ causal_mask = causal_mask.clone()
306
+ causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
307
+ token_type_mask, 0.0
308
+ )
309
+
310
+ if attention_mask is not None:
311
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
312
+ mask_length = attention_mask.shape[-1]
313
+
314
+ # Then apply padding mask (will mask pad tokens)
315
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
316
+ padding_mask = padding_mask == 0
317
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
318
+ padding_mask, min_dtype
319
+ )
320
+
321
+ return causal_mask
322
+
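To make the token-type handling in `_update_causal_mask` easier to follow, here is a small self-contained sketch of the same idea with toy sizes and plain `-inf` in place of `min_dtype`: positions marked as image tokens attend to each other bidirectionally, while text positions keep the usual causal pattern.

```python
import torch

seq = 6
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0]])  # 1 marks image tokens, 0 marks text

# Strictly upper-triangular -inf blocks attention to future positions.
causal = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)[None, None]  # (batch, head, q, k)

same_type = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)  # (batch, q, k)
same_type[token_type_ids == 0] = False           # text queries keep causal attention
causal = causal.clone()
causal[:, :, :, :seq] = causal[:, :, :, :seq].masked_fill(same_type.unsqueeze(1), 0.0)

print(causal[0, 0])  # the image block (rows/cols 1-3) is fully visible; text stays lower-triangular
```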
323
+ def get_image_features(self, pixel_values: torch.Tensor):
324
+ """
325
+ Projects the last hidden state from the vision model into language model space.
326
+
327
+ Args:
328
+ pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
329
+ The tensors corresponding to the input images.
330
+ Returns:
331
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
332
+ """
333
+ vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
334
+ image_features = self.multi_modal_projector(vision_outputs)
335
+ return image_features
336
+
337
+ def get_audio_features(self, input_audio_embeds: torch.FloatTensor, audio_attention_mask: torch.FloatTensor, audio_embed_sizes: torch.FloatTensor):
338
+ """
339
+ Projects the last hidden state from the audio model into language model space.
340
+
341
+ Args:
342
+ audio_inputs (`torch.FloatTensor]` of shape `(batch_size, sequence_length, feature_dim)`)
343
+ The tensors corresponding to the input audio features.
344
+
345
+ Returns:
346
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`).
347
+ """
348
+ audio_features, masks = self.audio_tower(input_audio_embeds, audio_attention_mask)
349
+ audio_outputs = self.audio_projector(audio_features)
350
+ return audio_outputs
351
+
352
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
353
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
354
+ @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
355
+ def forward(
356
+ self,
357
+ input_ids: torch.LongTensor = None,
358
+ pixel_values: torch.FloatTensor = None,
359
+ input_audio_embeds: torch.FloatTensor = None,
360
+ audio_embed_sizes: torch.FloatTensor = None,
361
+ audio_attention_mask: torch.FloatTensor = None,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ input_modes: torch.LongTensor = None,
364
+ position_ids: Optional[torch.LongTensor] = None,
365
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
366
+ token_type_ids: Optional[torch.LongTensor] = None,
367
+ cache_position: Optional[torch.LongTensor] = None,
368
+ inputs_embeds: Optional[torch.FloatTensor] = None,
369
+ labels: Optional[torch.LongTensor] = None,
370
+ use_cache: Optional[bool] = None,
371
+ output_attentions: Optional[bool] = None,
372
+ output_hidden_states: Optional[bool] = None,
373
+ return_dict: Optional[bool] = None,
374
+ logits_to_keep: Union[int, torch.Tensor] = 0,
375
+ **lm_kwargs,
376
+ ) -> Union[Tuple, Gemma3MMCausalLMOutputWithPast]:
377
+ r"""
378
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
379
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
380
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
381
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
382
+
383
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
384
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
385
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
386
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
387
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
388
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
389
+
390
+ Returns:
391
+
392
+ Example:
393
+
394
+ ```python
395
+ >>> from PIL import Image
396
+ >>> import requests
397
+ >>> from transformers import AutoProcessor, Gemma3MMForConditionalGeneration
398
+
399
+ >>> model = Gemma3MMForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
400
+ >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
401
+
402
+ >>> prompt = "answer en Where is the cow standing?"
403
+ >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
404
+ >>> image = Image.open(requests.get(url, stream=True).raw)
405
+
406
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
407
+
408
+ >>> # Generate
409
+ >>> generate_ids = model.generate(**inputs, max_length=30)
410
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
411
+ "answer en Where is the cow standing?\nbeach"
412
+ ```"""
413
+
414
+ if (input_ids is None) ^ (inputs_embeds is not None):
415
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
416
+
417
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
418
+ output_hidden_states = (
419
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
420
+ )
421
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
422
+
423
+ if isinstance(input_modes, torch.Tensor):
424
+ # len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
425
+ input_modes = input_modes.unique()
426
+ if len(input_modes) != 1:
427
+ raise ValueError("Elements of input_modes should have the same value")
428
+
429
+ input_mode = InputMode(input_modes.item())
430
+
431
+ if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
432
+ self.unset_lora_adapter()
433
+ #self.set_lora_adapter('vision')
434
+ #audio_projection_mode = 'vision'
435
+ elif input_mode == InputMode.SPEECH:
436
+ self.set_lora_adapter('speech')
437
+ #audio_projection_mode = 'speech'
438
+ elif input_mode == InputMode.LANGUAGE:
439
+ self.unset_lora_adapter()
440
+ #audio_projection_mode = 'speech'
441
+ else:
442
+ raise ValueError(f"Invalid input_mode: {input_mode}")
443
+
444
+ is_training = token_type_ids is not None and labels is not None
445
+
446
+ # Replace image/audio ids with PAD if the special token is OOV, to avoid index errors
447
+ if input_ids is not None and (self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size):
448
+ special_image_mask = input_ids == self.config.image_token_index
449
+ special_audio_mask = input_ids == self.config.audio_token_index
450
+ llm_input_ids = input_ids.clone()
451
+ llm_input_ids[special_image_mask] = 0
452
+ llm_input_ids[special_audio_mask] = 0
453
+ else:
454
+ llm_input_ids = input_ids
455
+
456
+ if inputs_embeds is None:
457
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
458
+
459
+ if cache_position is None:
460
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
461
+ cache_position = torch.arange(
462
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
463
+ )
464
+
465
+ if position_ids is None:
466
+ position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
467
+
468
+ # Merge text and images
469
+ if pixel_values is not None:
470
+ image_features = self.get_image_features(pixel_values)
471
+
472
+ if input_ids is None:
473
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
474
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
475
+ )
476
+ else:
477
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
478
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
479
+
480
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
481
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
482
+ raise ValueError(
483
+ f"Number of images does not match number of special image tokens in the input text. "
484
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
485
+ "tokens from image embeddings."
486
+ )
487
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
488
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
489
+
490
+ # Merge text and audios
491
+ if input_audio_embeds is not None:
492
+ audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
493
+ if input_ids is None:
494
+ special_audio_mask = inputs_embeds == self.get_input_embeddings()(
495
+ torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
496
+ )
497
+ else:
498
+ special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
499
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
500
+
501
+ masked_audio_features = []
502
+ for i, size in enumerate(audio_embed_sizes):
503
+ masked_audio_features.append(audio_features[i, :size, :])
504
+ masked_audio_features = torch.cat(masked_audio_features, dim=0)
505
+
506
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel():
507
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
508
+ masked_audio_size = audio_embed_sizes.sum()
509
+ raise ValueError(
510
+ f"Number of audios does not match number of special audio tokens in the input text. "
511
+ f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_size} "
512
+ "tokens from audio embeddings."
513
+ )
514
+ masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
515
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)
516
+
517
+ # mask out pad-token-ids in labels for BC
518
+ if labels is not None and self.pad_token_id in labels:
519
+ logger.warning_once(
520
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
521
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
522
+ )
523
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
524
+
525
+ causal_mask = self._update_causal_mask(
526
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
527
+ )
528
+ outputs = self.language_model(
529
+ attention_mask=causal_mask,
530
+ position_ids=position_ids,
531
+ past_key_values=past_key_values,
532
+ inputs_embeds=inputs_embeds,
533
+ use_cache=use_cache,
534
+ output_attentions=output_attentions,
535
+ output_hidden_states=output_hidden_states,
536
+ return_dict=return_dict,
537
+ cache_position=cache_position,
538
+ logits_to_keep=logits_to_keep,
539
+ **lm_kwargs,
540
+ )
541
+
542
+ logits = outputs.logits
543
+ loss = None
544
+ if labels is not None:
545
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
546
+ logits = logits.float()
547
+ shift_logits = logits[..., :-1, :]
548
+ shift_labels = labels[..., 1:]
549
+ if attention_mask is not None:
550
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
551
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
552
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
553
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
554
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
555
+ else:
556
+ shift_logits = shift_logits.contiguous()
557
+ shift_labels = shift_labels.contiguous()
558
+ # Flatten the tokens
559
+ loss_fct = nn.CrossEntropyLoss()
560
+
561
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
562
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
563
+ loss = loss_fct(flat_logits, flat_labels)
564
+ if not return_dict:
565
+ output = (logits,) + outputs[1:]
566
+ return (loss,) + output if loss is not None else output
567
+
568
+ return Gemma3MMCausalLMOutputWithPast(
569
+ loss=loss,
570
+ logits=logits,
571
+ past_key_values=outputs.past_key_values,
572
+ hidden_states=outputs.hidden_states,
573
+ attentions=outputs.attentions,
574
+ image_hidden_states=image_features if pixel_values is not None else None,
575
+ audio_hidden_states=audio_features if input_audio_embeds is not None else None,
576
+ )
577
+
578
+ def prepare_inputs_for_generation(
579
+ self,
580
+ input_ids,
581
+ past_key_values=None,
582
+ input_modes=None,
583
+ inputs_embeds=None,
584
+ cache_position=None,
585
+ position_ids=None,
586
+ pixel_values=None,
587
+ input_audio_embeds=None,
588
+ audio_embed_sizes=None,
589
+ audio_attention_mask=None,
590
+ attention_mask=None,
591
+ token_type_ids=None,
592
+ use_cache=True,
593
+ logits_to_keep=None,
594
+ labels=None,
595
+ **kwargs,
596
+ ):
597
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
598
+ model_inputs = self.language_model.prepare_inputs_for_generation(
599
+ input_ids,
600
+ past_key_values=past_key_values,
601
+ input_modes=input_modes,
602
+ inputs_embeds=inputs_embeds,
603
+ attention_mask=attention_mask,
604
+ position_ids=position_ids,
605
+ cache_position=cache_position,
606
+ use_cache=use_cache,
607
+ logits_to_keep=logits_to_keep,
608
+ token_type_ids=token_type_ids,
609
+ **kwargs,
610
+ )
611
+
612
+ # position_ids in Gemma3 are 1-indexed
613
+ if model_inputs.get("position_ids") is not None:
614
+ model_inputs["position_ids"] += 1
615
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
616
+ # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
617
+ if cache_position[0] == 0:
618
+ model_inputs["pixel_values"] = pixel_values
619
+ model_inputs["input_audio_embeds"] = input_audio_embeds
620
+ model_inputs["audio_embed_sizes"] = audio_embed_sizes
621
+ model_inputs["audio_attention_mask"] = audio_attention_mask
622
+ model_inputs["input_modes"] = input_modes
623
+ is_training = token_type_ids is not None and labels is not None
624
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
625
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
626
+ causal_mask = self._update_causal_mask(
627
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
628
+ )
629
+ model_inputs["attention_mask"] = causal_mask
630
+
631
+ return model_inputs
632
+
633
+ def tie_weights(self):
634
+ return self.language_model.tie_weights()
635
+
636
+
637
+ AutoConfig.register("gemma3mm", Gemma3MMConfig)
638
+ AutoModel.register(Gemma3MMConfig, Gemma3MMForConditionalGeneration)
639
+ Gemma3MMConfig.register_for_auto_class()
640
+ Gemma3MMForConditionalGeneration.register_for_auto_class()
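A hedged end-to-end sketch of loading a checkpoint built from these files; the repository id is a placeholder, the `<start_of_image>` string is assumed from the Gemma3 tokenizer, and the config is expected to carry the `auto_map` entries produced by the `register_for_auto_class()` calls above.

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "your-namespace/gemma3mm-checkpoint"  # hypothetical repo id
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

image = Image.new("RGB", (896, 896))             # stand-in image
prompt = "<start_of_image> Describe this image."

inputs = processor(text=prompt, images=[image], return_tensors="pt")
with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```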
processing_gemma3mm.py ADDED
@@ -0,0 +1,436 @@
1
+ import re
2
+ from typing import List, Optional, Union, Tuple
3
+ from math import ceil
4
+
5
+ import numpy as np
6
+ import torch
7
+ import scipy
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from enum import Enum
11
+
12
+ from transformers import AutoFeatureExtractor
13
+ from transformers.feature_extraction_utils import BatchFeature
14
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
15
+ from transformers.image_utils import ImageInput, make_nested_list_of_images
16
+ from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
17
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
+ from transformers.utils import to_py_obj, TensorType
19
+ from transformers.audio_utils import AudioInput
20
+
21
+ class InputMode(Enum):
22
+ LANGUAGE = 0
23
+ VISION = 1
24
+ SPEECH = 2
25
+ VISION_SPEECH = 3
26
+
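The enum values double as a two-bit bitmap, bit 0 for vision tokens and bit 1 for audio tokens, which is how `Gemma3MMProcessor.__call__` below derives the mode. A tiny sketch, assuming `InputMode` above is in scope:

```python
has_vision, has_audio = True, True
mode = InputMode((int(has_audio) << 1) | int(has_vision))
assert mode is InputMode.VISION_SPEECH  # 0b11 == 3
```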
27
+ class Gemma3ImagesKwargs(ImagesKwargs):
28
+ do_pan_and_scan: Optional[bool]
29
+ pan_and_scan_min_crop_size: Optional[int]
30
+ pan_and_scan_max_num_crops: Optional[int]
31
+ pan_and_scan_min_ratio_to_activate: Optional[float]
32
+ do_convert_rgb: Optional[bool]
33
+
34
+
35
+ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
36
+ images_kwargs: Gemma3ImagesKwargs
37
+ _defaults = {
38
+ "text_kwargs": {
39
+ "padding": False,
40
+ },
41
+ "images_kwargs": {
42
+ "do_pan_and_scan": False,
43
+ "pan_and_scan_min_crop_size": 256,
44
+ "pan_and_scan_max_num_crops": 4,
45
+ "pan_and_scan_min_ratio_to_activate": 1.2,
46
+ },
47
+ }
48
+
49
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
50
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
51
+ Args:
52
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
53
+ n_fft (int): FFT size. int > 0 [scalar]
54
+ n_mel (int): Mel filter size. int > 0 [scalar]
55
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
56
+ float >= 0 [scalar]
57
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
58
+ float >= 0 [scalar]
59
+ Returns
60
+ out (numpy.ndarray): Mel transform matrix
61
+ [shape=(n_mels, 1 + n_fft/2)]
62
+ """
63
+
64
+ bank_width = int(n_fft // 2 + 1)
65
+ if fmax is None:
66
+ fmax = sample_rate / 2
67
+ if fmin is None:
68
+ fmin = 0
69
+ assert fmin >= 0, "fmin cannot be negative"
70
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
71
+
72
+ def mel(f):
73
+ return 1127.0 * np.log(1.0 + f / 700.0)
74
+
75
+ def bin2mel(fft_bin):
76
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
77
+
78
+ def f2bin(f):
79
+ return int((f * n_fft / sample_rate) + 0.5)
80
+
81
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
82
+ klo = f2bin(fmin) + 1
83
+ khi = f2bin(fmax)
84
+
85
+ khi = max(khi, klo)
86
+
87
+ # Spec 2: SpeechLib uses triangles in Mel space
88
+ mlo = mel(fmin)
89
+ mhi = mel(fmax)
90
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
91
+ ms = (mhi - mlo) / (n_mels + 1)
92
+
93
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
94
+ for m in range(0, n_mels):
95
+ left = m_centers[m]
96
+ center = m_centers[m + 1]
97
+ right = m_centers[m + 2]
98
+ for fft_bin in range(klo, khi):
99
+ mbin = bin2mel(fft_bin)
100
+ if left < mbin < right:
101
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
102
+
103
+ return matrix
104
+
105
+
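A quick shape check for the filter-bank builder, assuming `speechlib_mel` above is in scope (band edges follow the 1127·ln(1 + f/700) Mel mapping used in `mel`/`bin2mel`):

```python
mel_matrix = speechlib_mel(sample_rate=16000, n_fft=512, n_mels=80)
print(mel_matrix.shape)  # (80, 257): one triangular filter per Mel band over 1 + n_fft // 2 FFT bins
```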
106
+ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
107
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
108
+ feature_extractor_type = "Gemma3AudioFeatureExtractor"
109
+ def __init__(self, **kwargs):
110
+ self.sampling_rate = kwargs.pop("sampling_rate", 16000)
111
+ self.feature_size = kwargs.pop("feature_size", 80)
112
+ self.padding_value = kwargs.pop("padding_value", 0.0)
113
+ super().__init__(sampling_rate=self.sampling_rate, feature_size=self.feature_size, padding_value=self.padding_value, **kwargs)
114
+
115
+ # Store these under the attribute names used by __call__ and _compute_audio_embed_size below.
+ self.audio_compression_rate = kwargs.get("audio_compression_rate", 8)
116
+ self.audio_downsample_rate = kwargs.get("audio_downsample_rate", 1)
117
+ self.audio_feat_stride = kwargs.get("audio_feat_stride", 1)
118
+
119
+ self._eightk_method = "fillzero"
120
+ self._mel = speechlib_mel(self.sampling_rate, 512, self.feature_size, fmin=None, fmax=self.sampling_rate//2-self.feature_size-230).T
121
+
122
+ self._hamming400 = np.hamming(400) # for 16k audio
123
+ self._hamming200 = np.hamming(200) # for 8k audio
124
+
125
+ def duration_to_frames(self, duration):
126
+ """duration in s, estimated frames"""
127
+ frame_rate = 10
128
+
129
+ num_frames = duration * 1000 // frame_rate
130
+ return num_frames
131
+
132
+ def __call__(
133
+ self,
134
+ audios: List[AudioInput],
135
+ return_tensors: Optional[Union[str, TensorType]] = None,
136
+ ):
137
+ # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
138
+ returned_input_audio_embeds = []
139
+ returned_audio_embed_sizes = []
140
+ audio_frames_list = []
141
+
142
+ for audio_data, sample_rate in audios:
143
+ audio_embeds = self._extract_features(audio_data, sample_rate)
144
+ audio_frames = len(audio_embeds) * self.audio_feat_stride
145
+ audio_embed_size = self._compute_audio_embed_size(audio_frames)
146
+ returned_input_audio_embeds.append(torch.tensor(audio_embeds))
147
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
148
+ audio_frames_list.append(audio_frames)
149
+
150
+ returned_input_audio_embeds = pad_sequence(
151
+ returned_input_audio_embeds, batch_first=True
152
+ )
153
+ returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
154
+ audio_frames = torch.tensor(audio_frames_list)
155
+ returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
156
+
157
+ data = {
158
+ "input_audio_embeds": returned_input_audio_embeds,
159
+ "audio_embed_sizes": returned_audio_embed_sizes,
160
+ }
161
+ if returned_audio_attention_mask is not None:
162
+ data["audio_attention_mask"] = returned_audio_attention_mask
163
+
164
+ return BatchFeature(data=data, tensor_type=return_tensors)
165
+
166
+ def _extract_spectrogram(self, wav, fs):
167
+ """Extract spectrogram features from waveform.
168
+ Args:
169
+ wav (1D array): waveform of the input
170
+ fs (int): sampling rate of the waveform, 16000 or 8000.
171
+ If fs=8000, the waveform will be resampled to 16000Hz.
172
+ Output:
173
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
174
+ D=80, and T is the number of frames.
175
+ """
176
+ if wav.ndim > 1:
177
+ wav = np.squeeze(wav)
178
+
179
+ # by default, we extract the mean if stereo
180
+ if len(wav.shape) == 2:
181
+ wav = wav.mean(1)
182
+
183
+ # Resample to 16000 or 8000 if needed
184
+ if fs > 16000:
185
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
186
+ fs = 16000
187
+ elif 8000 < fs < 16000:
188
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
189
+ fs = 8000
190
+ elif fs < 8000:
191
+ raise RuntimeError(f"Unsupported sample rate {fs}")
192
+
193
+ if fs == 8000:
194
+ if self._eightk_method == "resample":
195
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
196
+ # extraction
197
+ wav = scipy.signal.resample_poly(wav, 2, 1)
198
+ fs = 16000
199
+ # Do nothing here for fillzero method
200
+ elif fs != 16000:
201
+ # Input audio is not a supported sample rate.
202
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
203
+
204
+ preemphasis = 0.97
205
+
206
+ if fs == 8000:
207
+ n_fft = 256
208
+ win_length = 200
209
+ hop_length = 80
210
+ fft_window = self._hamming200
211
+ elif fs == 16000:
212
+ n_fft = 512
213
+ win_length = 400
214
+ hop_length = 160
215
+ fft_window = self._hamming400
216
+
217
+ # Spec 1: SpeechLib cut remaining sample insufficient for a hop
218
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
219
+ # Here we don't use stride_tricks since the input array may not satisfy
220
+ # memory layout requirement and we need writeable output
221
+ # Here we only use a list of views before copying to the destination
222
+ # so it is more efficient than broadcasting
223
+ y_frames = np.array(
224
+ [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
225
+ dtype=np.float32,
226
+ )
227
+
228
+ # Spec 2: SpeechLib applies preemphasis within each batch
229
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
230
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
231
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
232
+
233
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
234
+
235
+ if fs == 8000:
236
+ # Need to pad the output to look like 16 kHz data but with zeros in
237
+ # the 4 to 8 kHz bins.
238
+ frames, bins = S.shape
239
+ padarray = np.zeros((frames, bins))
240
+ S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
241
+
242
+ spec = np.abs(S).astype(np.float32)
243
+ return spec
244
+
245
+ def _extract_features(self, wav, fs):
246
+ """Extract log filterbank features from waveform.
247
+ Args:
248
+ wav (1D array): waveform of the input
249
+ fs (int): sampling rate of the waveform, 16000 or 8000.
250
+ If fs=8000, the waveform will be resampled to 16000Hz.
251
+ Output:
252
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
253
+ D=80, and T is the number of frames.
254
+ """
255
+ spec = self._extract_spectrogram(wav, fs)
256
+ spec_power = spec**2
257
+
258
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
259
+ log_fbank = np.log(fbank_power).astype(np.float32)
260
+
261
+ return log_fbank
262
+
263
+ def _compute_audio_embed_size(self, audio_frames):
264
+ integer = audio_frames // self.audio_compression_rate
265
+ remainder = audio_frames % self.audio_compression_rate
266
+
267
+ result = integer if remainder == 0 else integer + 1
268
+
269
+ integer = result // self.audio_downsample_rate
270
+ remainder = result % self.audio_downsample_rate
271
+ result = integer if remainder == 0 else integer + 1 # qformer compression
272
+
273
+ return result
274
+
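Worked example of the token bookkeeping above with the defaults (16 kHz audio, 400-sample window, 160-sample hop, `audio_feat_stride=1`, `audio_compression_rate=8`, `audio_downsample_rate=1`): a 4-second clip yields 398 log-Mel frames, which `_compute_audio_embed_size` rounds up to 50 audio tokens, the number of `audio_token` placeholders the processor reserves in the prompt.

```python
from math import ceil

frames = (4 * 16000 - 400) // 160 + 1          # 398 frames for a 4 s clip
audio_tokens = ceil(ceil(frames * 1 / 8) / 1)  # feat_stride, then compression, then qformer downsample
print(frames, audio_tokens)                    # 398 50
```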
275
+ class Gemma3MMProcessor(ProcessorMixin):
276
+ attributes = ["image_processor", "feature_extractor", "tokenizer"]
277
+ valid_kwargs = ["chat_template", "image_seq_length"]
278
+ image_processor_class = "AutoImageProcessor"
279
+ feature_extractor_class = "Gemma3AudioFeatureExtractor"
280
+ tokenizer_class = "AutoTokenizer"
281
+
282
+ def __init__(
283
+ self,
284
+ image_processor,
285
+ feature_extractor,
286
+ tokenizer,
287
+ chat_template=None,
288
+ image_seq_length: int = 256,
289
+ **kwargs,
290
+ ):
291
+ self.image_seq_length = image_seq_length
292
+ self.image_token_id = tokenizer.image_token_id
293
+ self.boi_token = tokenizer.boi_token
294
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
295
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
296
+
297
+ self.audio_token_id = tokenizer.audio_token_id
298
+ self.boa_token = tokenizer.boa_token
299
+ self.eoa_token = tokenizer.eoa_token
300
+ self.audio_token = tokenizer.audio_token
301
+
302
+ super().__init__(
303
+ image_processor=image_processor,
304
+ feature_extractor=feature_extractor,
305
+ tokenizer=tokenizer,
306
+ chat_template=chat_template,
307
+ **kwargs,
308
+ )
309
+
310
+ def __call__(
311
+ self,
312
+ images: ImageInput = None,
313
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
314
+ videos=None,
315
+ audios: List[AudioInput] = None,
316
+ **kwargs: Unpack[Gemma3ProcessorKwargs],
317
+ ) -> BatchFeature:
318
+ if text is None and images is None:
319
+ raise ValueError("Provide at least one of `text` or `images`.")
320
+
321
+ output_kwargs = self._merge_kwargs(
322
+ Gemma3ProcessorKwargs,
323
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
324
+ **kwargs,
325
+ )
326
+
327
+ if isinstance(text, str):
328
+ text = [text]
329
+ elif not isinstance(text, list) and not isinstance(text[0], str):
330
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
331
+
332
+ image_inputs = {}
333
+ if images is not None:
334
+ batched_images = make_nested_list_of_images(images)
335
+ image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
336
+
337
+ # Create empty text to be replaced with placeholders
338
+ if not text:
339
+ text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
340
+
341
+ if len(batched_images) != len(text):
342
+ raise ValueError(
343
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
344
+ )
345
+
346
+ # Replace image tokens by the full expanded sequence
347
+ batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
348
+ text_with_crops = text
349
+ for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
350
+ image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
351
+
352
+ if len(images) != len(image_indexes):
353
+ raise ValueError(
354
+ f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
355
+ )
356
+
357
+ # Insert additional image tokens for Pan-and-Scan crops
358
+ for num, idx in reversed(list(zip(num_crops, image_indexes))):
359
+ if num:
360
+ formatted_image_text = (
361
+ f"Here is the original image {self.boi_token} and here are some crops to help you see better "
362
+ + " ".join([self.boi_token] * num)
363
+ )
364
+ prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
365
+ text_with_crops[batch_idx] = prompt
366
+
367
+ # Expand placeholder image tokens to the full image token sequence
368
+ text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
369
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
370
+
371
+ audio_inputs = {}
372
+ if audios is not None:
373
+ def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
374
+ parts = prompt.split(boa_token)
375
+ result = ""
376
+ for i in range(len(parts) - 1):
377
+ result += parts[i]
378
+ if i < len(audio_sequences):
379
+ result += audio_sequences[i]
380
+ else:
381
+ result += boa_token
382
+ result += parts[-1]
383
+ return result
384
+
385
+ full_audio_sequences = []
386
+ audio_inputs = self.feature_extractor(audios)
387
+
388
+ for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
389
+ audio_tokens_expanded = "".join([self.audio_token] * embed_size)
390
+ full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
391
+ full_audio_sequences.append(full_audio_sequence)
392
+
393
+ text = [replace_tokens_sequentially(prompt, self.boa_token, [audio_sequences]) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
394
+
395
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
396
+
397
+ # Add token type ids manually, as tokenizer can't do arbitrary position token types
398
+ array_ids = np.array(text_inputs["input_ids"])
399
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
400
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
401
+ mm_token_type_ids[array_ids == self.audio_token_id] = 2
402
+
403
+ has_vision_ids = np.any(mm_token_type_ids == 1, axis=1)
404
+ has_audio_ids = np.any(mm_token_type_ids == 2, axis=1)
405
+
406
+ input_modes = (has_audio_ids << 1) | has_vision_ids
407
+
408
+ text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
409
+ text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
410
+ text_inputs["input_modes"] = input_modes.tolist()
411
+
412
+ return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs, }, tensor_type=return_tensors)
413
+
414
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
415
+ def batch_decode(self, *args, **kwargs):
416
+ """
417
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
418
+ refer to the docstring of this method for more information.
419
+ """
420
+ return self.tokenizer.batch_decode(*args, **kwargs)
421
+
422
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
423
+ def decode(self, *args, **kwargs):
424
+ """
425
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
426
+ the docstring of this method for more information.
427
+ """
428
+ return self.tokenizer.decode(*args, **kwargs)
429
+
430
+ @property
431
+ def model_input_names(self):
432
+ tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
433
+ image_processor_input_names = self.image_processor.model_input_names
434
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
435
+
436
+ AutoFeatureExtractor.register("Gemma3AudioFeatureExtractor", Gemma3AudioFeatureExtractor)
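A minimal sketch of the audio feature extractor in isolation, assuming `processing_gemma3mm.py` is importable; the waveform is synthetic:

```python
import numpy as np
from processing_gemma3mm import Gemma3AudioFeatureExtractor

feature_extractor = Gemma3AudioFeatureExtractor()
waveform = np.random.randn(16000).astype(np.float32)  # one second of fake 16 kHz audio

features = feature_extractor([(waveform, 16000)], return_tensors="pt")
print(features["input_audio_embeds"].shape)  # (1, 98, 80) log-Mel frames
print(features["audio_embed_sizes"])         # tensor([13]): audio tokens to reserve in the prompt
```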
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff