# coding=utf-8
# Copyright 2024 Google AI, LAION team. All rights reserved.
#
# This code is based on the open_clip framework. It has been modified from its
# original form to accommodate minor architectural differences compared
# to the original MaMMUT model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaMMUT configuration."""

from typing import Optional, Union

from transformers import AutoConfig, CLIPConfig, CLIPTextConfig, CLIPVisionConfig, PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MultimodalConfig(PretrainedConfig):
    """Shared multimodal (text decoder) settings for MaMMUT."""

    model_type = "mammut_text_model"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        context_length: int = 77,
        vocab_size: int = 49408,
        hidden_size: int = 1024,
        layers: int = 12,
        batch_first: bool = True,
        **kwargs,
    ):
        # Forward remaining kwargs to the parent config so standard
        # `PretrainedConfig` attributes (and any extra keys) are handled there.
        super().__init__(**kwargs)
        self.mlp_ratio = mlp_ratio
        self.dim_head = dim_head
        self.heads = heads
        self.n_queries = n_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.cross_attn_ratio = cross_attn_ratio
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens
        self.has_mlp = has_mlp
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = hidden_size
        self.layers = layers
        self.batch_first = batch_first


class MammutTextConfig(MultimodalConfig, CLIPTextConfig):
    """Configuration for the MaMMUT text decoder tower."""

    model_type = "mammut_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        max_position_embeddings: int = 77,
        vocab_size: int = 49408,
        num_hidden_layers: int = 12,
        hidden_size: int = 1024,
        attention_dropout: float = 0.0,
        hidden_act: str = "gelu",
        layer_norm_eps: float = 1e-5,
        intermediate_size: Optional[int] = None,
        initializer_factor: float = 0.02,
        logit_scale_init_value: float = 2.6592,
        **kwargs,
    ):
        super().__init__(
            mlp_ratio=mlp_ratio,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            attention_dropout=attention_dropout,
            logit_scale_init_value=logit_scale_init_value,
            max_position_embeddings=max_position_embeddings,
            layer_norm_eps=layer_norm_eps,
            intermediate_size=intermediate_size,
            initializer_factor=initializer_factor,
            hidden_act=hidden_act,
            **kwargs,
        )
        self.logit_scale_init_value = logit_scale_init_value
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens
        self.architectures = ["MammutTextModel"]
        # `hidden_size` and `num_attention_heads` are consumed under different
        # names along the MRO, so pin them explicitly here.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads


class MammutVisionConfig(CLIPVisionConfig):
    """Configuration for the MaMMUT vision tower."""

    model_type = "mammut_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        image_size: int = 224,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 12,
        **kwargs,
    ):
        # `width` and `layers` keep their open_clip-style names and are stored
        # alongside the CLIP-style fields by the parent config.
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            image_size=image_size,
            patch_size=patch_size,
            width=width,
            layers=layers,
            **kwargs,
        )
        self.num_attention_heads = num_attention_heads


class MammutConfig(CLIPConfig):
    """Top-level configuration combining the MaMMUT text and vision towers."""

    model_type = "mammut"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        text_config: Optional[Union[MammutTextConfig, dict]] = None,
        vision_config: Optional[Union[MammutVisionConfig, dict]] = None,
        projection_dim: int = 768,
        logit_scale_init_value: float = 2.6592,
        **kwargs,
    ):
        kwargs["architectures"] = ["MammutModel"]
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            **kwargs,
        )
        # Accept sub-configs either as ready-made config objects or as plain
        # dicts (e.g. when rebuilt from a serialized config.json).
        if text_config is None:
            text_config = MammutTextConfig()
        elif isinstance(text_config, dict):
            text_config = MammutTextConfig(**text_config)
        self.text_config = text_config

        if vision_config is None:
            vision_config = MammutVisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = MammutVisionConfig(**vision_config)
        self.vision_config = vision_config

        self.text_config.architectures = ["MammutTextModel"]
        self.vision_config.architectures = ["MammutVisionModel"]
        self.projection_dim = projection_dim
        self.hidden_size = self.text_config.hidden_size
        self.logit_scale_init_value = logit_scale_init_value
        self.architectures = ["MammutModel"]
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens

    def _post_init(self):
        # Propagate the shared logit scale down to the text tower.
        if self.logit_scale_init_value is not None:
            self.text_config.logit_scale_init_value = self.logit_scale_init_value
        super()._post_init()


AutoConfig.register("mammut", MammutConfig)
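

# A minimal usage sketch of how the configs above compose; the values below
# are illustrative only and are not canonical MaMMUT hyperparameters.
if __name__ == "__main__":
    # Sub-configs can be supplied as plain dicts (as they appear in a
    # serialized config.json) or as config objects; both are normalized
    # in MammutConfig.__init__.
    config = MammutConfig(
        text_config={"hidden_size": 768, "num_hidden_layers": 6},
        projection_dim=512,
    )
    print(config.text_config.hidden_size)   # 768
    print(config.vision_config.image_size)  # 224 (MammutVisionConfig default)
    print(config.hidden_size)               # mirrors the text tower: 768

    # Because "mammut" is registered with AutoConfig above, the config can
    # also be built generically by model type.
    auto_cfg = AutoConfig.for_model("mammut")
    print(type(auto_cfg).__name__)  # MammutConfig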