import os

import torch
import torch.nn as nn
from torchvision import transforms
from timm import create_model

from genconvit.config import load_config
from .model_embedder import HybridEmbed

config = load_config()

# Redirect torch.hub downloads and the Hugging Face hub cache to ./cache.
# NOTE(review): `timm` is imported above; if huggingface_hub reads
# HUGGINGFACE_HUB_CACHE at import time, this assignment may come too late to
# take effect -- confirm, and move it above the timm import if so.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"


class Encoder(nn.Module):
    """Convolutional VAE encoder mapping an RGB image to a sampled latent vector.

    Four stride-2 conv blocks (3 -> 16 -> 32 -> 64 -> 128 channels) shrink the
    spatial size by 16x. The flattened feature map feeds two linear heads: the
    posterior mean (``mu``) and log-variance (``var``). ``forward`` stores the
    weighted KL divergence in ``self.kl`` as a side effect.

    Args:
        latent_dims: Size of the latent vector returned by ``forward``.
    """

    def __init__(self, latent_dims=4):
        super(Encoder, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(),
        )
        self.latent_dims = latent_dims
        # fc1 / fc2 are not used by forward(); they are kept so that existing
        # checkpoints (state_dict keys) continue to load.
        self.fc1 = nn.Linear(128*14*14, 256)
        self.fc2 = nn.Linear(256, 128)
        # Both heads hard-code a 128x14x14 feature map, which is what the conv
        # stack produces for a 224x224 input (224 -> 112 -> 56 -> 28 -> 14).
        self.mu = nn.Linear(128*14*14, self.latent_dims)
        self.var = nn.Linear(128*14*14, self.latent_dims)  # log-variance head
        self.kl = 0
        self.kl_weight = 0.5  # an earlier value, 0.00025, was left commented out
        # Not used by forward(); kept for state_dict/attribute compatibility.
        self.relu = nn.LeakyReLU()

    @staticmethod
    def _sample(mu, log_var):
        """Draw z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I); return (z, std)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu, std

    def reparameterize(self, x):
        """Sample a latent vector with the reparameterization trick.

        # https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py

        The standard deviation is derived from the log-variance head
        (``self.var``), consistent with the KL term in ``forward``. Previously
        it was derived from the mean head, which left ``self.var`` without any
        effect on sampling.
        NOTE(review): checkpoints trained with the old behaviour may need
        re-validation.

        Args:
            x: Flattened encoder features (input to the ``mu``/``var`` heads).

        Returns:
            Tuple ``(z, std)``.
        """
        return self._sample(self.mu(x), self.var(x))

    def forward(self, x):
        """Encode ``x``, record the weighted KL term in ``self.kl``, return z."""
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.mu(x)
        log_var = self.var(x)
        # Reuse the head outputs instead of re-running the linear layers.
        z, _ = self._sample(mu, log_var)
        # KL(N(mu, exp(log_var)) || N(0, I)): summed over latent dims,
        # averaged over the batch, then scaled by kl_weight.
        self.kl = self.kl_weight * torch.mean(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1),
            dim=0,
        )
        return z


class Decoder(nn.Module):
    """Transposed-conv decoder that upsamples a flat latent vector to an image.

    The input is unflattened along dim 1 to (256, 7, 7), so its size on dim 1
    must be 256*7*7 = 12544. Four kernel-2 / stride-2 transposed convs then
    double the spatial size each time (7 -> 14 -> 28 -> 56 -> 112), giving a
    3-channel 112x112 output.

    Args:
        latent_dims: Stored on the module; the layer sizes do not depend on it.
    """

    def __init__(self, latent_dims=4):
        super(Decoder, self).__init__()
        self.features = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.LeakyReLU(),
        )
        self.latent_dims = latent_dims
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(256, 7, 7))

    def forward(self, x):
        """Reshape the latent vector to (256, 7, 7) and decode it."""
        x = self.unflatten(x)
        x = self.features(x)
        return x


class GenConViTVAE(nn.Module):
    """VAE branch plus a shared timm backbone, classifying an image and its reconstruction.

    The input and the VAE reconstruction both pass through the same backbone.
    The two outputs are concatenated and classified by a small MLP.

    Args:
        config: Dict providing ``config['model']['latent_dims' | 'embedder' |
            'backbone']``, ``config['img_size']`` and ``config['num_classes']``.
            ``latent_dims`` must be 12544 for the Decoder's Unflatten to accept
            the Encoder's output.
        pretrained: Passed to both ``create_model`` calls. Use False to skip
            downloading pretrained weights, e.g. when a checkpoint will be
            loaded afterwards. Defaults to True, matching the previous
            behaviour.
    """

    def __init__(self, config, pretrained=True):
        super(GenConViTVAE, self).__init__()
        self.latent_dims = config['model']['latent_dims']
        self.encoder = Encoder(self.latent_dims)
        self.decoder = Decoder(self.latent_dims)
        # Previously pretrained=True was hard-coded, silently ignoring the argument.
        self.embedder = create_model(config['model']['embedder'], pretrained=pretrained)
        self.convnext_backbone = create_model(
            config['model']['backbone'],
            pretrained=pretrained,
            num_classes=1000,
            drop_path_rate=0,
            head_init_scale=1.0,
        )
        self.convnext_backbone.patch_embed = HybridEmbed(
            self.embedder, img_size=config['img_size'], embed_dim=768
        )
        # Two backbone outputs are concatenated in forward(), hence * 2.
        self.num_feature = self.convnext_backbone.head.fc.out_features * 2
        self.fc = nn.Linear(self.num_feature, self.num_feature//4)
        # fc3 is not used by forward(); kept so existing checkpoints still load.
        self.fc3 = nn.Linear(self.num_feature//2, self.num_feature//4)
        self.fc2 = nn.Linear(self.num_feature//4, config['num_classes'])
        self.relu = nn.ReLU()
        self.resize = transforms.Resize((224, 224), antialias=True)

    def forward(self, x):
        """Classify ``x`` and return ``(logits, reconstruction resized to 224x224)``."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        x1 = self.convnext_backbone(x)
        x2 = self.convnext_backbone(x_hat)
        x = torch.cat((x1, x2), dim=1)
        x = self.fc2(self.relu(self.fc(self.relu(x))))
        return x, self.resize(x_hat)