|
import torch |
|
from torch import nn |
|
|
|
|
|
class Generator(nn.Module):
    """Conditional generator: maps a latent vector plus a class label to an image.

    The label is turned into a dense embedding, concatenated with the noise
    along the channel axis, and upsampled 1x1 -> 7x7 -> 14x14 -> 28x28 by
    transposed convolutions. Output is squashed to [-1, 1] by Tanh.
    """

    def __init__(self, config):
        super().__init__()

        self.latent_dim = config["latent_dim"]
        self.ngf = config["ngf"]
        self.nc = config["nc"]
        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # One learned dense vector per class id; joins the noise channels.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        width = self.ngf
        in_ch = self.latent_dim + self.embed_dim
        stages = [
            # 1x1 -> 7x7
            nn.ConvTranspose2d(in_ch, width * 4,
                               kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(width * 4, width * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(width * 2, width,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
            # Same spatial size; project to image channels in [-1, 1].
            nn.ConvTranspose2d(width, self.nc, kernel_size=3,
                               stride=1, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: tensor of shape (B, latent_dim, 1, 1) — the spatial 1x1
                shape is required by the channel concat below.
            labels: int64 tensor of shape (B,) with class ids.

        Returns:
            Tensor of shape (B, nc, 28, 28) with values in [-1, 1].
        """
        emb = self.label_embed(labels)[:, :, None, None]  # (B, embed_dim, 1, 1)
        stacked = torch.cat([noise, emb], dim=1)
        return self.main(stacked)
|
|
|
|
|
class Discriminator(nn.Module):
    """Conditional discriminator: scores an image given its class label.

    The label embedding is broadcast over the image's spatial grid and
    concatenated as extra input channels; strided convolutions reduce
    28x28 -> 14x14 -> 7x7 -> 1x1, ending in a Sigmoid probability.
    """

    def __init__(self, config):
        super().__init__()

        self.ndf = config["ndf"]
        self.nc = config["nc"]
        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # One learned dense vector per class id; tiled across the image plane.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        width = self.ndf
        stages = [
            # 28x28 -> 14x14 (no BatchNorm on the first layer)
            nn.Conv2d(self.nc + self.embed_dim, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7 -> 1x1 real/fake score
            nn.Conv2d(width * 2, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Score images.

        Args:
            img: tensor of shape (B, nc, H, W) — the conv stack above expects
                H = W = 28 so the final 7x7 conv collapses to 1x1.
            labels: int64 tensor of shape (B,) with class ids.

        Returns:
            Tensor of shape (B,) with per-image probabilities in [0, 1].
        """
        emb = self.label_embed(labels)[:, :, None, None]  # (B, embed_dim, 1, 1)
        emb = emb.expand(-1, -1, img.size(2), img.size(3))
        stacked = torch.cat((img, emb), dim=1)
        score = self.main(stacked)
        return score.view(-1, 1).squeeze(1)
|
|