# cGAN-MNIST / c_gan.py
# Uploaded by hussamalafandi using huggingface_hub (commit 691c76f).
import torch
from torch import nn
class Generator(nn.Module):
    """Conditional DCGAN generator that maps (noise, class label) to an image.

    The class label is embedded with ``nn.Embedding`` and concatenated to the
    noise along the channel dimension. A stack of transposed convolutions then
    upsamples the result 1x1 -> 7x7 -> 14x14 -> 28x28.

    Args:
        config: Mapping with the integer entries ``latent_dim`` (noise
            channels), ``ngf`` (base generator feature width), ``nc`` (output
            image channels), ``num_classes`` (number of label values) and
            ``embed_dim`` (label-embedding size).
    """

    def __init__(self, config):
        super().__init__()
        self.latent_dim = config["latent_dim"]
        self.ngf = config["ngf"]
        self.nc = config["nc"]
        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # Label embedding: maps each label to a vector of size embed_dim.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        # DCGAN generator architecture.
        self.main = nn.Sequential(
            # Input: (latent_dim + embed_dim) x 1 x 1 -> (ngf*4) x 7 x 7.
            nn.ConvTranspose2d(self.latent_dim + self.embed_dim, self.ngf * 4,
                               kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(self.ngf * 2, self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # Final layer: project to nc channels, preserving 28x28.
            nn.ConvTranspose2d(self.ngf, self.nc, kernel_size=3,
                               stride=1, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images conditioned on ``labels``.

        Args:
            noise: Latent tensor shaped ``(batch, latent_dim, 1, 1)`` or
                ``(batch, latent_dim)``. The 2-D form is reshaped to 4-D.
            labels: Integer tensor shaped ``(batch,)`` with values in
                ``[0, num_classes)``.

        Returns:
            Tensor shaped ``(batch, nc, 28, 28)`` with values in ``[-1, 1]``
            (the final activation is ``Tanh``).
        """
        # Accept flat noise vectors as well as the DCGAN-style 4-D layout;
        # without this, torch.cat below fails on a rank mismatch.
        if noise.dim() == 2:
            noise = noise.unsqueeze(-1).unsqueeze(-1)
        # Embed labels and reshape to (batch, embed_dim, 1, 1) so they can be
        # concatenated with the noise along the channel dimension.
        label_embedding = self.label_embed(labels).unsqueeze(2).unsqueeze(3)
        gen_input = torch.cat([noise, label_embedding], dim=1)
        return self.main(gen_input)
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator that scores (image, class label) pairs.

    The label embedding is broadcast over the image's spatial dimensions and
    concatenated to the image as extra channels. Two strided convolutions
    reduce 28x28 -> 14x14 -> 7x7, and a final 7x7 convolution plus sigmoid
    yields one score per image.

    Args:
        config: Mapping with the integer entries ``ndf`` (base discriminator
            feature width), ``nc`` (input image channels), ``num_classes``
            (number of label values) and ``embed_dim`` (label-embedding size).
    """

    def __init__(self, config):
        super().__init__()
        self.ndf = config["ndf"]
        self.nc = config["nc"]
        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # Label embedding: maps each label to a vector of size embed_dim.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        # DCGAN discriminator architecture.
        self.main = nn.Sequential(
            # Input: (nc + embed_dim) x 28 x 28
            nn.Conv2d(self.nc + self.embed_dim, self.ndf, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf) x 14 x 14
            nn.Conv2d(self.ndf, self.ndf * 2, kernel_size=4, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*2) x 7 x 7; the 7x7 kernel collapses it to 1 x 1.
            nn.Conv2d(self.ndf * 2, 1, kernel_size=7, stride=1, padding=0,
                      bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score each image as real/fake given its label.

        Args:
            img: Image tensor shaped ``(batch, nc, H, W)``. The network is
                built for 28x28 inputs.
            labels: Integer tensor shaped ``(batch,)`` with values in
                ``[0, num_classes)``.

        Returns:
            Tensor shaped ``(batch,)`` of sigmoid scores in ``[0, 1]``.

        Raises:
            ValueError: If the image size does not reduce to a 1x1 score map
                (e.g. 32x32 inputs). In that case flattening would otherwise
                silently return more scores than there are images.
        """
        # Embed the labels and replicate them over every spatial position.
        label_embedding = self.label_embed(labels).unsqueeze(2).unsqueeze(3)
        label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
        # Concatenate the image with the label embedding along channels.
        d_in = torch.cat((img, label_embedding), dim=1)
        out = self.main(d_in)
        if out.shape[2:] != (1, 1):
            raise ValueError(
                f"Discriminator expects 28x28 inputs; got {img.size(2)}x{img.size(3)}, "
                f"which yields a {out.size(2)}x{out.size(3)} score map instead of 1x1"
            )
        return out.view(-1, 1).squeeze(1)