# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling import ImageEncoderViT


def _build_vit_encoder(
    embed_dim: int,
    depth: int,
    num_heads: int,
    global_attn_indexes: list,
) -> ImageEncoderViT:
    """Construct an ``ImageEncoderViT`` with the settings shared by all variants.

    Every public ``vit_encoder_*`` factory in this module uses the same image
    size (1024), patch size (16), output channels (256), window size (14),
    MLP ratio (4) and LayerNorm epsilon (1e-6); only the four arguments below
    differ between them.

    Args:
        embed_dim: Forwarded as ``embed_dim`` to ``ImageEncoderViT``.
        depth: Forwarded as ``depth`` to ``ImageEncoderViT``.
        num_heads: Forwarded as ``num_heads`` to ``ImageEncoderViT``.
        global_attn_indexes: Forwarded as ``global_attn_indexes``; presumably
            the indexes of the blocks that use global rather than windowed
            attention -- confirm against ``ImageEncoderViT``.

    Returns:
        A freshly constructed ``ImageEncoderViT`` instance.
    """
    return ImageEncoderViT(
        depth=depth,
        embed_dim=embed_dim,
        img_size=1024,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=num_heads,
        patch_size=16,
        qkv_bias=True,
        use_rel_pos=True,
        # Absolute position embeddings are explicitly switched off here.
        use_abs_pos=False,
        global_attn_indexes=global_attn_indexes,
        window_size=14,
        out_chans=256,
    )


def vit_encoder_l():
    """Return the "l" encoder: embed_dim=1024, depth=24, 16 heads."""
    return _build_vit_encoder(
        embed_dim=1024,
        depth=24,
        num_heads=16,
        # A fresh list per call, so instances never share this argument.
        global_attn_indexes=[5, 11, 17, 23],
    )


def vit_encoder_b():
    """Return the "b" encoder: embed_dim=768, depth=12, 12 heads."""
    return _build_vit_encoder(
        embed_dim=768,
        depth=12,
        num_heads=12,
        # A fresh list per call, so instances never share this argument.
        global_attn_indexes=[2, 5, 8, 11],
    )


def vit_encoder_h():
    """Return the "h" encoder: embed_dim=1280, depth=32, 16 heads."""
    return _build_vit_encoder(
        embed_dim=1280,
        depth=32,
        num_heads=16,
        # A fresh list per call, so instances never share this argument.
        global_attn_indexes=[7, 15, 23, 31],
    )