import os

import torch
import torch.nn as nn
import timm

from .model_embedder import HybridEmbed

# Keep downloaded weights in a local cache so repeated runs do not re-download.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
os.environ['TORCH_HOME'] = '/models'

class Encoder(nn.Module):
    """Five-stage convolutional encoder: 3 -> 256 channels, 32x spatial downsampling."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Each stage: 3x3 conv (same padding) -> ReLU -> 2x2 max pool (halves H and W).
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.features(x)

class Decoder(nn.Module):
    """Mirror of the Encoder: five transposed convolutions, each doubling H and W."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.features(x)

class GenConViTED(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        # Backbone and embedder are hardcoded here. The original config-driven
        # variant read the model names from config['model']['backbone'] and
        # config['model']['embedder'], with img_size from config['img_size'],
        # and could load local .pth checkpoints via load_state_dict instead of
        # downloading pretrained weights.
        self.backbone = timm.create_model('convnext_tiny', pretrained=pretrained)
        self.embedder = timm.create_model('swin_tiny_patch4_window7_224', pretrained=pretrained)

        # Replace the ConvNeXt patch embedding with a hybrid embedder that
        # tokenizes the input through the Swin transformer.
        self.backbone.patch_embed = HybridEmbed(self.embedder, img_size=224, embed_dim=768)

        # Features from the two backbone passes (reconstruction and original
        # image) are concatenated, hence the factor of 2.
        self.num_features = self.backbone.head.fc.out_features * 2
        self.fc = nn.Linear(self.num_features, self.num_features // 4)
        self.fc2 = nn.Linear(self.num_features // 4, 2)
        self.activation = nn.GELU()  # GELU, despite the original attribute name `relu`

    def forward(self, images):
        # Autoencoder pass: encode, then reconstruct the input images.
        encimg = self.encoder(images)
        decimg = self.decoder(encimg)

        # Extract backbone features from both the reconstruction and the original.
        x1 = self.backbone(decimg)
        x2 = self.backbone(images)

        # Fuse the two feature vectors and classify (2 classes: real / fake).
        x = torch.cat((x1, x2), dim=1)
        x = self.fc2(self.activation(self.fc(self.activation(x))))
        return x
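
# Minimal smoke test (my addition, not part of the original module): builds the
# model without pretrained weights and checks the output shape. Because of the
# relative import of HybridEmbed above, this must be run as a package module,
# e.g. `python -m model.genconvit_ed` (module path assumed), not directly.
if __name__ == "__main__":
    model = GenConViTED(pretrained=False)  # relies on __init__ honoring `pretrained`
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)  # batch of two RGB images at the assumed 224x224 size
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 2])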