import os

import torch
import torch.nn as nn
import timm

from .model_embedder import HybridEmbed

# Keep downloaded backbone weights in a local cache directory.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
os.environ['TORCH_HOME'] = '/models'


class Encoder(nn.Module):
    """Convolutional encoder: five conv + stride-2 max-pool stages that compress an
    RGB image into a 256-channel latent."""

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.features(x)


class Decoder(nn.Module):
    """Convolutional decoder: five stride-2 transposed-conv stages that mirror the
    Encoder and upsample the latent back to a 3-channel image."""

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.features(x)

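
# Shape round trip for the default 224x224 inputs, derived from the layers above:
#   Encoder: (N, 3, 224, 224) -> (N, 256, 7, 7)    five /2 downsampling stages
#   Decoder: (N, 256, 7, 7)   -> (N, 3, 224, 224)  five x2 upsampling stages
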
class GenConViTED(nn.Module):
    """Encoder-decoder variant of GenConViT: a ConvNeXt backbone with a Swin-based
    hybrid patch embedding extracts features from both the input image and its
    autoencoder reconstruction, and the fused features are classified as real/fake."""

    def __init__(self, pretrained=True):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

        # ConvNeXt backbone whose patch embedding is replaced by a Swin Transformer
        # wrapped in HybridEmbed, combining CNN and Transformer features.
        self.backbone = timm.create_model('convnext_tiny', pretrained=pretrained)
        self.embedder = timm.create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)
        self.backbone.patch_embed = HybridEmbed(self.embedder, img_size=224, embed_dim=768)

        # The backbone runs on both the input and its reconstruction, so the fused
        # feature vector is twice the backbone's output width (2 x 1000 for the
        # default ImageNet-1k head of convnext_tiny).
        self.num_features = self.backbone.head.fc.out_features * 2
        self.fc = nn.Linear(self.num_features, self.num_features // 4)
        self.fc2 = nn.Linear(self.num_features // 4, 2)
        self.act = nn.GELU()

    def forward(self, images):
        # Autoencoder pass: compress the input and reconstruct it.
        encimg = self.encoder(images)
        decimg = self.decoder(encimg)

        # Backbone features of the reconstruction and of the original image.
        x1 = self.backbone(decimg)
        x2 = self.backbone(images)

        # Fuse both feature vectors and classify.
        x = torch.cat((x1, x2), dim=1)
        x = self.fc2(self.act(self.fc(self.act(x))))

        return x
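

# Minimal smoke-test sketch, not part of the model itself. It assumes 224x224 RGB
# inputs and that this file is run as a module (python -m ...) so the relative
# import of HybridEmbed resolves; pretrained=False only skips the weight download.
if __name__ == "__main__":
    model = GenConViTED(pretrained=False)
    model.eval()

    dummy = torch.randn(2, 3, 224, 224)  # batch of two 224x224 RGB frames
    with torch.no_grad():
        logits = model(dummy)

    print(logits.shape)  # expected: torch.Size([2, 2]) for the two-class output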