# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional

from timm.models.layers import to_2tuple
from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed

from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_alibi_bias,
    MaskSeed,
)
from .modules import (
    BlockEncoder,
    Decoder2d,
    FixedPositionalEncoder,
    TransformerDecoder,
    EncDecTransformerDecoder,
)
from examples.data2vec.data.modality import Modality


@dataclass
class D2vImageConfig(D2vModalityConfig):
    type: Modality = Modality.IMAGE

    input_size: int = 224
    in_chans: int = 3
    patch_size: int = 16
    embed_dim: int = 768

    alibi_dims: int = 2
    alibi_distance: str = "manhattan"

    fixed_positions: bool = True

    transformer_decoder: bool = False
    enc_dec_transformer: bool = False


class ImageEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vImageConfig

    def __init__(
        self,
        modality_cfg: D2vImageConfig,
        embed_dim: int,
        make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        img_size = to_2tuple(modality_cfg.input_size)
        patch_size = to_2tuple(modality_cfg.patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])

        local_encoder = PatchEmbed(
            modality_cfg.input_size,
            modality_cfg.patch_size,
            modality_cfg.in_chans,
            modality_cfg.embed_dim,
        )

        # initialize the patch projection like an nn.Linear (instead of nn.Conv2d)
        w = local_encoder.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        if modality_cfg.embed_dim != embed_dim:
            local_encoder = nn.Sequential(
                local_encoder,
                nn.Linear(modality_cfg.embed_dim, embed_dim),
            )

        project_features = nn.Identity()

        # fixed (non-learned) 2d sin-cos positional embeddings
        pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )

        side_n = int(num_patches ** 0.5)

        emb = get_2d_sincos_pos_embed(
            pos_embed.shape[-1],
            side_n,
            cls_token=False,
        )
        pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
        fixed_positional_encoder = (
            FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
        )

        # linearly increasing drop-path rates across the pre-net blocks
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        # choose between a transformer decoder and the default 2d decoder
        if modality_cfg.transformer_decoder:
            if modality_cfg.enc_dec_transformer:
                decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
            else:
                dec_enc = BlockEncoder(
                    nn.ModuleList(
                        make_block(0, modality_cfg.decoder.decoder_dim, 8)
                        for _ in range(modality_cfg.decoder.decoder_layers)
                    ),
                    None,
                    layer_norm_first,
                    0,
                    0,
                )
                decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
        else:
            decoder = (
                Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
                if modality_cfg.decoder is not None
                else None
            )

        alibi_bias_fn = partial(
            get_alibi_bias,
            alibi_biases=alibi_biases,
            heads=modality_cfg.num_alibi_heads,
            dims=modality_cfg.alibi_dims,
            distance=modality_cfg.alibi_distance,
        )

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=fixed_positional_encoder,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()

    @torch.no_grad()
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 * 3)
        """
        p = self.modality_cfg.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

        return x

    @torch.no_grad()
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * 3)
        imgs: (N, 3, H, W)
        """
        p = self.modality_cfg.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        shape=None,
        precomputed_mask=None,
    ):
        # for mask_length > 1, sample contiguous 2d blocks of patches;
        # otherwise fall back to the generic masking in the base class
        mlen = self.modality_cfg.mask_length
        if mlen <= 1:
            return super().compute_mask(
                x, padding_mask, mask_seed, apply, precomputed_mask
            )

        if precomputed_mask is not None:
            mask = precomputed_mask
        else:
            from fairseq.data.data_utils import compute_block_mask_2d

            if shape is not None:
                B, L, D = shape
            else:
                B, L, D = x.shape

            mask = compute_block_mask_2d(
                shape=(B, L),
                mask_prob=self.modality_cfg.mask_prob,
                mask_length=self.modality_cfg.mask_length,
                mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
                inverse_mask=self.modality_cfg.inverse_mask,
                require_same_masks=True,
                mask_dropout=self.modality_cfg.mask_dropout,
            )

        mask_info = self.make_maskinfo(x, mask, shape)
        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def decoder_input(self, x, mask_info):
        # only the enc-dec transformer decoder needs the custom query/key-value
        # split below; every other decoder uses the generic input preparation
        if (
            not self.modality_cfg.transformer_decoder
            or not self.modality_cfg.enc_dec_transformer
        ):
            return super().decoder_input(x, mask_info)

        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        # keys/values come from the visible tokens (extra tokens stripped)
        kv = x[:, self.modality_cfg.num_extra_tokens :]

        assert self.fixed_positional_encoder is not None
        pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)

        mask = mask_info.mask.bool()
        if self.modality_cfg.decoder.add_positions_all:
            kv = kv + pos[~mask].view(kv.shape)

        # queries are the positional embeddings of the masked positions
        q = pos[mask].view(x.size(0), -1, x.size(-1))

        return q, kv
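
# Illustrative sketch (not part of the original file): a minimal check that
# patchify/unpatchify are exact inverses. Constructing a full ImageEncoder
# requires make_block/norm_layer factories from the wrapping model, so the
# stand-in class below is an assumption for demonstration only: it bypasses
# __init__ and sets just the config that patchify/unpatchify read.
if __name__ == "__main__":

    class _PatchOps(ImageEncoder):
        def __init__(self, cfg):
            # skip ModalitySpecificEncoder.__init__; only modality_cfg is used
            nn.Module.__init__(self)
            self.modality_cfg = cfg

    cfg = D2vImageConfig()  # defaults: 224x224 input, 16x16 patches, 3 channels
    ops = _PatchOps(cfg)

    imgs = torch.randn(2, cfg.in_chans, cfg.input_size, cfg.input_size)
    patches = ops.patchify(imgs)
    assert patches.shape == (2, (224 // 16) ** 2, 16 ** 2 * 3)  # (N, 196, 768)
    # reshape + einsum are pure permutations, so the round trip is lossless
    assert torch.allclose(ops.unpatchify(patches), imgs)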