import os

import torch
import torch.nn as nn
from torchvision import transforms
from timm import create_model
import timm

from .model_embedder import HybridEmbed

# Cache locations for downloaded weights.
# NOTE(review): torch.hub is pointed at './cache' while TORCH_HOME is set to
# '/models'; confirm which location is actually intended.
# NOTE(review): the environment variables are assigned after timm has been
# imported; confirm the libraries read them lazily, otherwise they may need
# to be set before the imports above.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
os.environ['TORCH_HOME'] = '/models'


class Encoder(nn.Module):
    """Convolutional encoder: five Conv(3x3) -> ReLU -> MaxPool(2x2) stages.

    Channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256. Each convolution keeps
    the spatial size (kernel 3, stride 1, padding 1) and each pooling halves
    it, so the output is 1/32 of the input height and width (floored).
    """

    def __init__(self):
        super().__init__()
        channels = (3, 16, 32, 64, 128, 256)
        layers = []
        # Layers are appended in a flat list so the indices inside
        # ``self.features`` (and therefore the state-dict keys) stay stable.
        for in_ch, out_ch in zip(channels, channels[1:]):
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), stride=(1, 1),
                          padding=(1, 1))
            )
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
            )
        self.features = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoded feature map for the image batch ``x``."""
        return self.features(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring ``Encoder``.

    Five ConvTranspose(2x2, stride 2) -> ReLU stages reduce the channels
    256 -> 128 -> 64 -> 32 -> 16 -> 3 while doubling the spatial size at each
    stage (32x overall). The final ReLU makes the output non-negative.
    """

    def __init__(self):
        super().__init__()
        channels = (256, 128, 64, 32, 16, 3)
        layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(2, 2),
                                   stride=(2, 2))
            )
            layers.append(nn.ReLU(inplace=True))
        self.features = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the reconstruction decoded from the feature map ``x``."""
        return self.features(x)


class GenConViTED(nn.Module):
    """Autoencoder + ConvNeXt/Swin hybrid two-class classifier.

    The input images and their encoder/decoder reconstruction are both passed
    through the same backbone; the two backbone outputs are concatenated and
    mapped to two logits by a small MLP head.

    Args:
        pretrained: Whether timm should load pretrained weights for the
            ConvNeXt backbone and the Swin embedder. Pass ``False`` to build
            the architecture only, e.g. before loading a checkpoint.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        # Honour the ``pretrained`` argument (it used to be ignored and the
        # pretrained weights were always requested).
        self.backbone = timm.create_model('convnext_tiny',
                                          pretrained=pretrained)
        self.embedder = timm.create_model('swin_tiny_patch4_window7_224',
                                          pretrained=pretrained)

        # Replace the backbone's patch embedding with the project's
        # HybridEmbed wrapper around the Swin model.
        # NOTE(review): img_size=224 / embed_dim=768 are hard-coded here;
        # confirm they match what HybridEmbed and the backbone expect.
        self.backbone.patch_embed = HybridEmbed(self.embedder, img_size=224,
                                                embed_dim=768)

        # Two backbone outputs are concatenated in ``forward``, hence "* 2".
        self.num_features = self.backbone.head.fc.out_features * 2
        self.fc = nn.Linear(self.num_features, self.num_features // 4)
        self.fc2 = nn.Linear(self.num_features // 4, 2)
        # The attribute keeps its historical name although it holds a GELU.
        self.relu = nn.GELU()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            images: Image batch with 3 channels; presumably 224x224 to match
                the hard-coded ``img_size`` -- TODO confirm.

        Returns:
            Tensor with two output values (logits) per sample.
        """
        reconstructed = self.decoder(self.encoder(images))

        recon_feats = self.backbone(reconstructed)
        image_feats = self.backbone(images)

        combined = torch.cat((recon_feats, image_feats), dim=1)
        hidden = self.fc(self.relu(combined))
        return self.fc2(self.relu(hidden))