# mimosa-ai / genconvit / model_embedder.py
# Author: vivekk3
# Commit message: "Upload folder using huggingface_hub"
# Commit: 9c4b01e (verified)
import torch
import torch.nn as nn
import os
# Redirect both the torch.hub download directory and the Hugging Face hub
# cache to one local directory. The path is relative, so it is resolved
# against the process's current working directory.
# NOTE(review): huggingface_hub presumably reads HUGGINGFACE_HUB_CACHE when it
# is first imported; if it was imported before this module, this assignment
# may have no effect. Confirm the import order.
_CACHE_DIR = './cache'
torch.hub.set_dir(_CACHE_DIR)
os.environ["HUGGINGFACE_HUB_CACHE"] = _CACHE_DIR
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    The backbone's feature map (the last one, if the backbone returns a
    list/tuple) is projected to ``embed_dim`` channels by a convolution whose
    kernel and stride are both ``patch_size``. The result is flattened into a
    token sequence of shape ``(batch, num_patches, embed_dim)``.

    Args:
        backbone: CNN ``nn.Module`` whose forward returns a feature map, or a
            list/tuple of feature maps (only the last is used).
        img_size: Input image size, either an int (square) or an ``(H, W)`` pair.
        patch_size: Patch size on the feature map, int or ``(h, w)`` pair.
        feature_size: Spatial size of the backbone output, int or ``(H, W)``
            pair. When None, it is measured by running a dummy forward pass.
        in_chans: Number of input channels (only used for the dummy pass).
        embed_dim: Number of output embedding channels.

    Attributes:
        grid_size: ``(rows, cols)`` of patches on the feature map.
        num_patches: ``rows * cols``, i.e. the output sequence length.
    """

    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module), "backbone must be an nn.Module"
        img_size = self._to_2tuple(img_size)
        patch_size = self._to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            feature_size, feature_dim = self._probe_backbone(backbone, in_chans, img_size)
        else:
            feature_size = self._to_2tuple(feature_size)
            # Without a forward pass, the channel count must come from backbone
            # metadata: `feature_info` if it exists, otherwise `num_features`.
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0, (
            f"feature map size {tuple(feature_size)} is not divisible by patch size {patch_size}"
        )
        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _to_2tuple(value):
        """Return ``value`` as a 2-tuple: scalars are repeated, pairs pass through."""
        if isinstance(value, (tuple, list)):
            pair = tuple(value)
            assert len(pair) == 2, f"expected a pair, got {value!r}"
            return pair
        return (value, value)

    @staticmethod
    def _probe_backbone(backbone, in_chans, img_size):
        """Run a dummy forward pass and return ``((H, W), channels)`` of the last feature map."""
        # Build the probe on the backbone's own device/dtype, so that a backbone
        # already moved to GPU or cast to half precision can still be probed.
        ref = next(backbone.parameters(), None)
        if ref is None:
            ref = next(backbone.buffers(), None)
        device = ref.device if ref is not None else None
        dtype = ref.dtype if ref is not None and ref.is_floating_point() else None
        training = backbone.training
        try:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                if training:
                    backbone.eval()
                o = backbone(torch.zeros(1, in_chans, img_size[0], img_size[1], device=device, dtype=dtype))
        finally:
            # Restore the caller's train/eval mode even if the probe forward raises.
            backbone.train(training)
        if isinstance(o, (list, tuple)):
            o = o[-1]  # last feature if backbone outputs list/tuple of features
        return tuple(o.shape[-2:]), o.shape[1]

    def forward(self, x):
        """Embed images ``x`` into tokens of shape ``(batch, num_patches, embed_dim)``."""
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        # proj -> (B, embed_dim, gh, gw); flatten(2) -> (B, embed_dim, gh*gw);
        # transpose -> (B, gh*gw, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x