# coding=utf-8 # Copyright 2024 Google AI, LAION team. team. All rights reserved. # # This code is based on open_clip framework. It has been modified from its # original forms to accommodate minor architectural differences compared # to the original MaMMUT model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch MaMMUT model.""" from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from .configuration_mammut import MammutTextConfig, MammutVisionConfig, MammutConfig from transformers.models.clip.modeling_clip import ( CLIPAttention, CLIPMLP, CLIPEncoderLayer, CLIPTextModel, CLIPVisionModel, CLIPVisionModelOutput, CLIPVisionTransformer, CLIPTextModelOutput, CLIPOutput, CLIPModel, CLIPPreTrainedModel, CLIPVisionEmbeddings, CLIPEncoder, eager_attention_forward ) # noqa: E501 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from transformers.generation import GenerateDecoderOnlyOutput from dataclasses import dataclass from typing import Optional, Tuple, Union from transformers import AutoModel import logging from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers import ( BeamSearchScorer, LogitsProcessorList, TopPLogitsWarper, TopKLogitsWarper, RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, StoppingCriteriaList ) log = logging.getLogger(__name__) class MammutCrossAttnLayer(nn.Module): def __init__(self, config: MammutTextConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = MammutAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1_kv = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, print0_hidden_states: bool = False, ) -> torch.Tensor: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) if k_x is not None and v_x is not None: k_x = self.layer_norm1_kv(k_x) v_x = self.layer_norm1_kv(v_x) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, keys=k_x, values=v_x, print0_hidden_states=print0_hidden_states, ) hidden_states = hidden_states.permute(1, 0, 2) # (seq_length, batch_size, embed_dim) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class MammutAttention(CLIPAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Union[MammutTextConfig, MammutVisionConfig]): super().__init__(config) self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 # self.scale = 1 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.training = False # Set to True by default, can be changed during training or evaluation def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, keys: Optional[torch.Tensor] = None, values: Optional[torch.Tensor] = None, print0_hidden_states: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape if keys is None and values is None: keys = hidden_states values = hidden_states #TODO: CLIP attention interface # keys = self.k_proj(keys) # values = self.v_proj(values) # if print0_hidden_states: # # print("head_dim:", self.head_dim) # print("query shape:", queries.shape) # print("key shape:", keys.shape) # print("value shape:", values.shape) # queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) # keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) # values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) # CLIP text model uses both `causal_attention_mask` and `attention_mask` # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` # if self.config._attn_implementation == "flash_attention_2": # self.is_causal = causal_attention_mask is not None # else: # if attention_mask is not None and causal_attention_mask is not None: # attention_mask = attention_mask + causal_attention_mask # elif causal_attention_mask is not None: # attention_mask = causal_attention_mask # attention_interface: Callable = eager_attention_forward # if self.config._attn_implementation != "eager": # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = F.multi_head_attention_forward( query=hidden_states.permute(1, 0, 2), # (seq_length, batch_size, embed_dim) key=keys.permute(1, 0, 2) if keys is not None else hidden_states.permute(1, 0, 2), value=values.permute(1, 0, 2) if values is not None else hidden_states.permute(1, 0, 2), embed_dim_to_check=embed_dim, num_heads=self.num_heads, in_proj_weight=torch.cat( [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0 ), in_proj_bias=torch.cat( [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias], dim=0 ) if self.q_proj.bias is not None else None, bias_k=None, bias_v=None, add_zero_attn=False, attn_mask=attention_mask, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, is_causal=self.is_causal, dropout_p=0.0 if not self.training else self.dropout, out_proj_weight=self.out_proj.weight, out_proj_bias=self.out_proj.bias, training=self.training, # Use the training flag to control dropout ) # attn_output, attn_weights = attention_interface( # self, # queries, # (seq_length, batch_size, embed_dim) # keys, # values, # attention_mask, # is_causal=self.is_causal, # scaling=self.scale, # dropout=0.0 if not self.training else self.dropout, # output_attentions=output_attentions, # ) # attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() # attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights class MammutEncoderLayer(CLIPEncoderLayer): def __init__(self, config: MammutTextConfig, has_mlp: bool = True): super().__init__(config) self.embed_dim = config.hidden_size self.self_attn = MammutAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config) if has_mlp else None self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, print_hidden_states: bool = False, ) -> Tuple[torch.FloatTensor]: """ Forward pass for the encoder layer. Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. causal_attention_mask (`torch.FloatTensor`, *optional*): causal attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, causal_attention_mask=None, output_attentions=output_attentions, print0_hidden_states=print_hidden_states, ) hidden_states = hidden_states.permute(1, 0, 2) # (seq_length, batch_size, embed_dim) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) if self.mlp is not None else hidden_states hidden_states = residual + hidden_states return hidden_states class MammutMultimodalEncoder(nn.Module): does_full_decoding: torch.jit.Final[bool] def __init__( self, config: MammutConfig, ): super().__init__() self.config = config self.n_cross_attn, _ = divmod(config.num_hidden_layers, config.cross_attn_ratio) self.cross_step, _ = divmod(config.num_hidden_layers, self.n_cross_attn) self.does_full_decoding = config.does_full_decoding self.output_tokens = config.output_tokens self.batch_first = config.batch_first self.context_length = config.max_position_embeddings self.layers = nn.ModuleList([]) self.cross_attn = nn.ModuleList([]) num_cross_attn = 0 for l_idx in range(config.num_hidden_layers): _, r = divmod(l_idx, self.cross_step) has_cross_attn = r == 0 layer = MammutEncoderLayer(config) self.layers.append(layer) if has_cross_attn: num_cross_attn += 1 cross_attn_layer = MammutCrossAttnLayer(config) self.cross_attn.append(cross_attn_layer) def forward( self, text_embeds: torch.Tensor, img_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = text_embeds seq_len = hidden_states.shape[1] if self.batch_first else hidden_states.shape[0] if causal_attention_mask is None: causal_attention_mask = self.build_causal_mask() else: causal_attention_mask = causal_attention_mask.to(dtype=hidden_states.dtype) if attention_mask is None: attention_mask = causal_attention_mask else: attention_mask = attention_mask + causal_attention_mask if img_embeds is not None: img_embeds = img_embeds.to(dtype=hidden_states.dtype) k_x = img_embeds v_x = img_embeds else: k_x = None v_x = None if img_embeds is not None: attention_mask = attention_mask[:seq_len, :seq_len] for i, layer in enumerate(self.layers): cross_attn_idx, r = divmod(i, self.cross_step) has_cross_attn = r == 0 and img_embeds is not None if i == 0: print_hidden_states = True else: print_hidden_states = False hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask if img_embeds is not None else None, causal_attention_mask=None, output_attentions=output_attentions, print_hidden_states=print_hidden_states, ) if has_cross_attn: cross_attn = self.cross_attn[cross_attn_idx] hidden_states = cross_attn( hidden_states=hidden_states, k_x=k_x, v_x=v_x, print0_hidden_states=i== 0, # attention_mask=attention_mask, # causal_attention_mask=causal_attention_mask, ) if output_hidden_states: encoder_states = tuple(encoder_states) if self.does_full_decoding: encoder_states = encoder_states[:self.n_cross_attn + 1] else: encoder_states = encoder_states[:self.config.text_config.num_hidden_layers] else: encoder_states = None return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) def build_causal_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def build_attn_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask @dataclass class MammutPoolingOutput(BaseModelOutputWithPooling): """ Base class for outputs of the Mammut model. """ last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None output_ids: Optional[torch.Tensor] = None pooler_output: Optional[torch.FloatTensor] = None class MammutMultimodalEmbeddings(nn.Module): def __init__(self, config: MammutTextConfig): super().__init__() self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) self.position_embedding = nn.Embedding( config.max_position_embeddings, config.hidden_size ) self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) def forward( self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] max_position_embedding = self.position_embedding.weight.shape[0] if seq_length > max_position_embedding: raise ValueError( f"Sequence length must be less than max_position_embeddings (got `sequence length`: " f"{seq_length} and max_position_embeddings: {max_position_embedding}" ) if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): if pool_type == 'first': pooled, tokens = x[:, 0], x[:, 1:] elif pool_type == 'last': pooled, tokens = x[:, -1], x[:, :-1] elif pool_type == 'argmax': # take features from the eot embedding (eot_token is the highest number in each sequence) assert text is not None pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x else: pooled = tokens = x return pooled, tokens class MammutMultimodalTransformer(nn.Module): def __init__(self, config: MammutTextConfig, output_tokens=True): super().__init__() self.config = config embed_dim = config.hidden_size self.encoder = MammutMultimodalEncoder(config) self.text_projection = nn.Linear( config.hidden_size, config.vocab_size, bias=False ) if config.hidden_size is not None else None self.final_layer_norm = nn.LayerNorm( embed_dim, eps=config.layer_norm_eps ) # self.init_weights() self.does_full_decoding = config.does_full_decoding self.context_length = config.context_length self.vocab_size = config.vocab_size width = config.hidden_size self.batch_first = config.batch_first self.has_mlp = config.has_mlp self.cross_attn_ratio = config.cross_attn_ratio self.cross_step = config.cross_attn_ratio self.n_cross_attn = config.num_hidden_layers // config.cross_attn_ratio vocab_size = config.vocab_size self.output_tokens = output_tokens if self.does_full_decoding: self.num_pos = self.context_length self.embeddings = MammutMultimodalEmbeddings(config) else: self.num_pos = None self.embeddings = None def init_weights(self): self.final_layer_norm.weight.data.fill_(1.0) self.final_layer_norm.bias.data.zero_() log.info("MammutMultimodalTransformer weights initialized.") def forward( self, img_embs: torch.Tensor, text_embs: Optional[torch.Tensor] = None, output_tokens: Optional[bool] = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, ) -> Union[CLIPVisionModelOutput, CLIPTextModelOutput]: if text_embs is not None: if self.embeddings is not None: # print("text_embs shape:", text_embs.shape) text_embs = self.embeddings( input_ids=text_embs, position_ids=position_ids, # inputs_embeds=img_embs if img_embs is not None else None, ) if self.does_full_decoding: text_embs = text_embs[:, :self.context_length, :] text_embs = self.encoder( text_embeds=text_embs, img_embeds=img_embs, attention_mask=None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) text_embs = text_embs.last_hidden_state if self.does_full_decoding: text_embs = text_embs[:, :self.context_length, :] else: text_embs = text_embs[:, 0, :] if self.text_projection is not None: output_ids = self.text_projection(text_embs) else: output_ids = text_embs if output_tokens: return MammutPoolingOutput( last_hidden_state=text_embs, # Last hidden state is the text embeddings hidden_states=None, # No hidden states in this implementation attentions=None, # No attentions in this implementation output_ids=output_ids, # Placeholder for output tokens pooler_output=text_embs, # Pooler output is the text embeddings ) return MammutPoolingOutput( last_hidden_state=text_embs, # Last hidden state is the text embeddings pooler_output=text_embs, hidden_states=None, # No hidden states in this implementation attentions=None, # No attentions in this implementation ) def build_causal_mask(self, seq_len: Optional[int] = None, device: Optional[torch.device] = None) -> torch.Tensor: if seq_len is None: seq_len = self.context_length if self.does_full_decoding else self.config.context_length if device is None: device = torch.device("cpu") mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len) return mask def build_attn_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask class MammutMultimodalModel(CLIPTextModel): """ Mammut multimodal model with text and vision encoders. """ config_class = MammutTextConfig base_model_prefix = "mammut_multimodal" def __init__(self, config: MammutTextConfig): super().__init__(config) self.config = config.text_config self.text_model = MammutMultimodalTransformer(config.text_config) self.text_embed_dim = config.hidden_size self.vision_embed_dim = config.vision_config.hidden_size self.projection_dim = config.projection_dim # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embs: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_tokens: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, ) -> Union[MammutPoolingOutput, CLIPTextModelOutput]: return self.text_model( img_embs=image_embs, text_embs=input_ids, output_tokens=output_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, position_ids=position_ids, ) class MammutVisionTransformer(CLIPVisionTransformer): """ Mammut Vision Transformer model. Inherits from CLIPVisionTransformer and initializes the vision model. """ config_class = MammutVisionConfig base_model_prefix = "mammut_vision" def __init__(self, config: MammutVisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings(config) self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = CLIPEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.pool_type = config.pool_type def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.pool_type == 'avg': pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] elif self.pool_type == 'tok': pooled, tokens = x[:, 0], x[:, 1:] elif self.pool_type == "avg_all": pooled, tokens = x.mean(dim=1), x else: pooled = tokens = x return pooled, tokens def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, ) -> BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.pre_layrnorm(hidden_states) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) last_hidden_state = encoder_outputs.last_hidden_state pooled_output = last_hidden_state[:, 0, :] if self.config.final_ln_after_pool: pooled, _ = self._global_pool(last_hidden_state) pooled_output = self.post_layernorm(pooled) else: pooled_output = self.post_layernorm(pooled_output) pooled, _ = self._global_pool(pooled_output) pooled_output = pooled return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class MammutVisionModel(CLIPVisionModel): """ Mammut Vision Model. Inherits from CLIPVisionModel and initializes the vision model. """ config_class = MammutVisionConfig base_model_prefix = "mammut_vision" def __init__(self, config: MammutVisionConfig): super().__init__(config) self.config = config self.vision_model = MammutVisionTransformer(config) self.post_init() @dataclass class MammutContrastiveOutput(CLIPOutput): """ Output class for Mammut model in contrastive learning mode. Contains contrastive output: - loss: Loss value if return_loss is True. - logits_per_text: Logits for text inputs. - logits_per_image: Logits for image inputs. - text_embeds: Text embeddings. - image_embeds: Image embeddings. """ loss: Optional[torch.FloatTensor] = None logits_per_text: Optional[torch.FloatTensor] = None logits_per_image: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None image_embeds: Optional[torch.FloatTensor] = None @dataclass class MammutCaptioningOutput(ModelOutput): """ Output class for Mammut captioning part. Contains: - last_hidden_state: Last hidden state of the text model. - pooler_output: Pooler output of the text model. - hidden_states: Hidden states from the text model. - attentions: Attention weights from the text model. - output_ids: Output tokens from the text model. """ last_hidden_state: torch.FloatTensor = None pooler_output: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None output_ids: Optional[torch.Tensor] = None @dataclass class MammutOutput(ModelOutput): """ Output class for Mammut model. Contains contrastive output: - loss: Loss value if return_loss is True. - logits_per_text: Logits for text inputs. - logits_per_image: Logits for image inputs. - text_embeds: Text embeddings. - image_embeds: Image embeddings. Captioning output: - text_model_output: Output from the text model. - output_ids: Output tokens from the text model. """ loss: Optional[torch.FloatTensor] = None logits_per_text: Optional[torch.FloatTensor] = None logits_per_image: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None image_embeds: Optional[torch.FloatTensor] = None text_model_output: Optional[MammutCaptioningOutput] = None output_ids: Optional[torch.Tensor] = None # @dataclass # class MammutGenerationOutput(GenerateDecoderOnlyOutput) def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor: """ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566 """ square_tensor = torch.pow(tensor, 2) sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True) normed_tensor = torch.pow(sum_tensor, 0.5) return normed_tensor class MammutModel(CLIPPreTrainedModel): """ Mammut model with text and vision encoders. """ config_class = MammutConfig base_model_prefix = "mammut" def __init__(self, config: MammutConfig): super().__init__(config) self.config = config self.text_model = MammutMultimodalTransformer(config.text_config, output_tokens=config.output_tokens) vision_model = MammutVisionModel._from_config(config.vision_config) self.vision_model = vision_model.vision_model self.text_embed_dim = config.text_config.hidden_size self.vision_embed_dim = config.vision_config.hidden_size self.projection_dim = config.projection_dim self.text_projection = self.text_model.text_projection self.visual_projection = nn.Linear( self.vision_embed_dim, self.projection_dim, bias=False ) if self.projection_dim is not None else None self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) self.map_viz2txt_kv = nn.Parameter(torch.randn( self.config.vision_config.width, self.config.text_config.width )) self.eos_token_id = self.config.text_config.eos_token_id self.bos_token_id = self.config.text_config.bos_token_id self.pad_token_id = self.config.text_config.pad_token_id self.does_full_decoding = config.text_config.does_full_decoding self.context_length = config.text_config.context_length self.vocab_size = config.text_config.vocab_size self.batch_first = config.text_config.batch_first # Initialize weights and apply final processing self.post_init() def get_text_features( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, img_embs: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ Get text features from the Mammut model. """ text_model_output = self.text_model( img_embs=img_embs, text_embs=input_ids, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) text_embeds = text_model_output.last_hidden_state text_embeds = self.text_model.final_layer_norm(text_embeds) text_embeds = text_embeds.mean(1) text_embeds = F.normalize(text_embeds, dim=-1) return text_embeds def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, normalize: bool = True, ) -> torch.FloatTensor: """ Get image features from the Mammut model. """ vision_outputs: CLIPVisionModelOutput = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) image_embeds = vision_outputs.pooler_output if self.visual_projection is not None: image_embeds = self.visual_projection(image_embeds) image_embeds = F.normalize(image_embeds, dim=-1) if normalize else image_embeds return image_embeds def _contrastive_forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, output_tokens: Optional[bool] = None, contrastive: Optional[bool] = False, ) -> MammutContrastiveOutput: """ Forward pass for the Mammut model in contrastive learning mode. - **Two-pass learning:** to unify contrastive and next-token prediction, we need to unify unconditional representation learning and token-conditioned next-token prediction objective. - **First pass: contrastive task.** For the first pass, text features should not see image features (dual-encoder contrastive learner) but attend to all tokens at once to produce sequence-level representation. Cross-attention and causal masking is disabled. - **Second pass: captioning task.** Using cross attention and causal masking learn caption generation task. Return: MammutContrastiveOutput: Contains contrastive output with logits, embeddings, and optional loss. """ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) vision_outputs: CLIPVisionModelOutput = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) # text_model is MammutMultimodalTransformer, which handles text embeddings text_outputs: MammutPoolingOutput = self.text_model( img_embs=None, # No image embeddings in contrastive forward pass for text model text_embs=input_ids, output_tokens=output_tokens, output_attentions=output_attentions, output_hidden_states=output_hidden_states, position_ids=position_ids, ) image_embeds = vision_outputs.pooler_output image_embeds = self.visual_projection(image_embeds) text_embeds = text_outputs.pooler_output pooled, tokens = text_global_pool(text_embeds, text=input_ids) text_embeds = self.text_model.final_layer_norm(text_embeds) text_embeds = text_embeds.mean(1) tokens = self.text_projection(pooled) # Normalize the embeddings image_embeds = image_embeds / _get_vector_norm(image_embeds) text_embeds = text_embeds / _get_vector_norm(text_embeds) # cosine similarity as logits logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device) logits_per_image = logits_per_text.t() loss = None return MammutContrastiveOutput( loss=loss, logits_per_text=logits_per_text, logits_per_image=logits_per_image, text_embeds=text_embeds, image_embeds=image_embeds, ) def _captioning_forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_embeds: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, output_tokens: Optional[bool] = None, ) -> MammutCaptioningOutput: """ Forward pass for the Mammut model in captioning mode. Return: MammutCaptioningOutput: Contains captioning output with last hidden state, pooler output, hidden states, attentions, and output tokens. """ if pixel_values is None: raise ValueError("Pixel values must be provided for captioning.") if input_ids is None: input_ids = torch.ones( (pixel_values.shape[0], self.context_length), dtype=torch.long, device=pixel_values.device ) * self.bos_token_id # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if image_embeds is None: vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs.last_hidden_state image_embeds = image_embeds @ self.map_viz2txt_kv text_model_output = self.text_model( img_embs=image_embeds, # Use image embeddings for captioning text_embs=input_ids, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) text_embeds = text_model_output.last_hidden_state text_embeds = self.text_model.final_layer_norm(text_embeds) logits = self.text_projection(text_embeds) if output_tokens: return MammutCaptioningOutput( last_hidden_state=text_embeds, pooler_output=image_embeds, # Placeholder for pooler output output_ids=logits, # Output tokens from the text model ) return MammutCaptioningOutput( last_hidden_state=text_embeds, pooler_output=image_embeds, # Placeholder for pooler output output_ids=None, # No output tokens in this case ) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, output_tokens: Optional[bool] = False, contrastive_only: Optional[bool] = False, captioning_only: Optional[bool] = False, ) -> MammutOutput: """ Forward pass for the Mammut model. - **Two-pass learning:** to unify contrastive and next-token prediction, we need to unify unconditional representation learning and token-conditioned next-token prediction objective. - **First pass: contrastive task.** For the first pass, text features should not see image features (dual-encoder contrastive learner) but attend to all tokens at once to produce sequence-level representation. Cross-attention and causal masking is disabled. - **Second pass: captioning task.** Using cross attention and causal masking learn caption generation task. """ # first pass: contrastive task # second pass: captioning task if pixel_values is None and input_ids is None: raise ValueError("Pixel values or input IDs must be provided for captioning.") if output_tokens is None: output_tokens = self.config.output_tokens if output_tokens and not self.config.output_tokens: raise ValueError("Output tokens are not enabled in the configuration.") if output_tokens and pixel_values is None: raise ValueError("Pixel values must be provided if output tokens are enabled.") if output_tokens and input_ids is None: # Only captioning captioning_only = True if input_ids is not None and pixel_values is not None: contrastive_output = self._contrastive_forward( input_ids=input_ids, pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) else: contrastive_output = MammutContrastiveOutput( loss=None, logits_per_text=None, logits_per_image=None, text_embeds=None, image_embeds=None, ) if contrastive_only: # If only contrastive output is needed, return it directly return MammutOutput( loss=contrastive_output.loss, logits_per_text=contrastive_output.logits_per_text, logits_per_image=contrastive_output.logits_per_image, text_embeds=contrastive_output.text_embeds, image_embeds=contrastive_output.image_embeds, ) if captioning_only: # If only captioning output is needed, return it directly text_model_output = self._captioning_forward( input_ids=input_ids, pixel_values=pixel_values, # No pixel values for captioning only attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, output_tokens=output_tokens, ) return MammutOutput( loss=None, # No loss in captioning only mode logits_per_text=None, # No logits in captioning only mode logits_per_image=None, # No logits in captioning only mode text_embeds=text_model_output.last_hidden_state, # Use last hidden state as text embeddings image_embeds=None, # No image embeddings in captioning only mode text_model_output=text_model_output, # Output from the text model output_ids=text_model_output.output_ids, # Output tokens from the text model ) # If both contrastive and captioning outputs are needed, return both text_model_output = self._captioning_forward( input_ids=input_ids, pixel_values=pixel_values, # No pixel values for captioning only attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, output_tokens=output_tokens, ) return MammutOutput( loss=contrastive_output.loss, logits_per_text=contrastive_output.logits_per_text, logits_per_image=contrastive_output.logits_per_image, text_embeds=contrastive_output.text_embeds, image_embeds=contrastive_output.image_embeds, text_model_output=text_model_output, # Output from the text model output_ids=text_model_output.output_ids, # Output tokens from the text model ) @torch.no_grad() def generate( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, max_new_tokens: int = 20, do_sample: bool = False, temperature: float = 1.0, repetition_penalty: float = 1.0, top_p: float = 0, top_k: int = 0, min_seq_len: int = 1, stopping_criteria= None, ) -> GenerateDecoderOnlyOutput: """ Generate captions using the Mammut model. Args: input_ids (torch.LongTensor, optional): Input token IDs for the text model. pixel_values (torch.FloatTensor, optional): Pixel values for the vision model. attention_mask (torch.Tensor, optional): Attention mask for the text model. position_ids (torch.LongTensor, optional): Position IDs for the text model. max_new_tokens (int): Maximum length of the generated sequence. do_sample (bool): Whether to sample from the distribution or take argmax. temperature (float): Temperature for sampling. repetition_penalty (float): Penalty for repetition in sampling. top_p (float): Top-p sampling parameter. top_k (int): Top-k sampling parameter. min_seq_len (int): Minimum sequence length for generation. stopping_criteria: Stopping criteria for generation. Returns: GenerateDecoderOnlyOutput: Contains the generated sequences and logits. """ # This method should implement the generation logic for the Mammut model. if input_ids is None and pixel_values is None: raise ValueError("Input IDs or pixel values must be provided for generation.") if input_ids is None: input_ids = torch.ones( (pixel_values.shape[0], 1), dtype=torch.long, device=pixel_values.device ) * self.bos_token_id if pixel_values is None: raise ValueError("Pixel values must be provided for generation.") self.eval() device = pixel_values.device if pixel_values is not None else input_ids.device if input_ids is None: input_ids = torch.ones( (pixel_values.shape[0], 1), dtype=torch.long, device=device ) * self.bos_token_id eos_token_id = self.eos_token_id if self.eos_token_id is not None else self.text_model.config.eos_token_id logit_processor = LogitsProcessorList( [ MinLengthLogitsProcessor(min_seq_len, eos_token_id), RepetitionPenaltyLogitsProcessor(repetition_penalty), ] ) if do_sample: if top_k > 0: logit_warper = LogitsProcessorList( [ TopKLogitsWarper(top_k), ] ) if top_p > 0: logit_warper = LogitsProcessorList( [ TopPLogitsWarper(top_p), ] ) if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_new_tokens)] stopping_criteria = StoppingCriteriaList( stopping_criteria ) out = input_ids vision_outputs = self.vision_model( pixel_values=pixel_values ) image_embeds = vision_outputs.last_hidden_state with torch.no_grad(): while True: x = out[:, -max_new_tokens:] # Get text features captioning_output = self._captioning_forward( input_ids=x, pixel_values=pixel_values, image_embeds=image_embeds, attention_mask=attention_mask, position_ids=position_ids, output_attentions=False, output_hidden_states=False, interpolate_pos_encoding=False, output_tokens=True, # We want the output tokens ) output_ids = captioning_output.output_ids # Get logits for the next token logits = output_ids[:, -1] mask = (out[:, -1] == eos_token_id) | (out[:, -1] == self.pad_token_id) logits = logits[~mask, :] filtered_logits = logit_processor(x[~mask, :], logits) filtered_logits = logit_warper(x[~mask, :], filtered_logits) # Sample or take the argmax of the logits cur_len = out.shape[1] if cur_len >= max_new_tokens: next_token = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id elif do_sample: probs = F.softmax(filtered_logits / temperature, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(filtered_logits, dim=-1, keepdim=True) if mask.all(): break # Check if we have reached the end of the sequence or max length if (out.shape[1] >= max_new_tokens) or (next_token == eos_token_id).all(): break # Append the next token to the output sequence out = torch.cat([out, next_token], dim=1) output_ids = out.long() if out.dtype != torch.long else out # If we reach the end of the sequence or max length, break the loop return GenerateDecoderOnlyOutput( logits=logits, sequences=output_ids, # Output tokens from the text model ) AutoModel.register(MammutConfig, MammutModel)