import os

import torch
import torch.nn as nn
from timm import create_model
from torchvision import transforms

from genconvit.config import load_config
from .model_embedder import HybridEmbed

config = load_config()

# Keep downloaded model weights in a local cache directory.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
class Encoder(nn.Module):
    """Convolutional encoder that maps an image to a latent code via the
    VAE reparameterization trick."""

    def __init__(self, latent_dims=4):
        super().__init__()

        # Four stride-2 conv blocks: a 3 x 224 x 224 input becomes 128 x 14 x 14.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.LeakyReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU()
        )

        self.latent_dims = latent_dims
        self.fc1 = nn.Linear(128 * 14 * 14, 256)  # unused in forward()
        self.fc2 = nn.Linear(256, 128)            # unused in forward()
        self.mu = nn.Linear(128 * 14 * 14, self.latent_dims)   # mean head
        self.var = nn.Linear(128 * 14 * 14, self.latent_dims)  # log-variance head

        self.kl = 0
        self.kl_weight = 0.5
        self.relu = nn.LeakyReLU()
    def reparameterize(self, mu, log_var):
        # Sample z = mu + eps * std, with std derived from the log-variance head.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)

        mu = self.mu(x)
        log_var = self.var(x)
        z = self.reparameterize(mu, log_var)

        # KL divergence between q(z|x) = N(mu, exp(log_var)) and the N(0, I)
        # prior, stored on the module so a training loop can pick it up.
        self.kl = self.kl_weight * torch.mean(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0
        )

        return z
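# Illustrative only: a training loop would typically combine the stashed KL
# term with reconstruction and classification losses, e.g. (names hypothetical):
#
#   logits, x_hat = model(x)
#   loss = recon_loss(x_hat, x) + model.encoder.kl + cls_loss(logits, y)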
class Decoder(nn.Module):
    """Transposed-convolution decoder that reconstructs an image from a
    flattened latent code.

    Note: unflatten expects a 256 * 7 * 7 = 12544-dimensional input, so
    latent_dims must equal 12544 for forward() to work.
    """

    def __init__(self, latent_dims=4):
        super().__init__()

        # Four stride-2 transposed convs: 256 x 7 x 7 becomes 3 x 112 x 112.
        self.features = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.LeakyReLU()
        )

        self.latent_dims = latent_dims

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(256, 7, 7))

    def forward(self, x):
        x = self.unflatten(x)
        x = self.features(x)
        return x
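# Illustrative shape check (not part of the module): with 224x224 inputs the
# encoder flattens to 128 * 14 * 14 features, while the decoder's unflatten
# needs a 256 * 7 * 7 = 12544-dimensional latent, so latent_dims must be 12544
# (presumably what the repo's config sets) for encoder and decoder to chain:
#
#   enc = Encoder(latent_dims=12544)
#   dec = Decoder(latent_dims=12544)
#   z = enc(torch.randn(1, 3, 224, 224))  # -> (1, 12544)
#   x_hat = dec(z)                        # -> (1, 3, 112, 112)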
class GenConViTVAE(nn.Module):
    """GenConViT VAE branch: a ConvNeXt backbone with a hybrid ViT patch
    embedder classifies both the input image and its VAE reconstruction."""

    def __init__(self, config, pretrained=True):
        super().__init__()
        self.latent_dims = config['model']['latent_dims']
        self.encoder = Encoder(self.latent_dims)
        self.decoder = Decoder(self.latent_dims)
        self.embedder = create_model(config['model']['embedder'], pretrained=pretrained)
        self.convnext_backbone = create_model(
            config['model']['backbone'],
            pretrained=pretrained,
            num_classes=1000,
            drop_path_rate=0,
            head_init_scale=1.0,
        )
        # Swap the backbone's patch embedding for the hybrid ViT embedder.
        self.convnext_backbone.patch_embed = HybridEmbed(self.embedder, img_size=config['img_size'], embed_dim=768)
        # Backbone features of the input and its reconstruction are concatenated.
        self.num_feature = self.convnext_backbone.head.fc.out_features * 2

        self.fc = nn.Linear(self.num_feature, self.num_feature // 4)
        self.fc3 = nn.Linear(self.num_feature // 2, self.num_feature // 4)  # unused in forward()
        self.fc2 = nn.Linear(self.num_feature // 4, config['num_classes'])
        self.relu = nn.ReLU()
        self.resize = transforms.Resize((224, 224), antialias=True)
    def forward(self, x):
        # Encode the input and decode a reconstruction.
        z = self.encoder(x)
        x_hat = self.decoder(z)

        # Run the backbone on both the input and its reconstruction, then
        # fuse the two feature vectors for classification.
        x1 = self.convnext_backbone(x)
        x2 = self.convnext_backbone(x_hat)
        x = torch.cat((x1, x2), dim=1)
        x = self.fc2(self.relu(self.fc(self.relu(x))))

        return x, self.resize(x_hat)
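
# Minimal usage sketch (illustrative, not part of the repo): one forward pass,
# assuming load_config() supplies the keys referenced above and the timm
# weights can be downloaded.
if __name__ == "__main__":
    model = GenConViTVAE(config, pretrained=True)
    model.eval()
    with torch.no_grad():
        logits, reconstruction = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)          # (2, num_classes)
    print(reconstruction.shape)  # (2, 3, 224, 224)
    print(model.encoder.kl)      # weighted KL term for the VAE loss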