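"""Conditional DCGAN (cDCGAN) Generator and Discriminator for 28x28 images
(e.g. MNIST). Both networks condition on class labels via a learned
nn.Embedding: the Generator concatenates the embedding with the latent
noise vector, and the Discriminator concatenates it channel-wise with the
input image.
"""
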
import torch
from torch import nn


class Generator(nn.Module):
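    """Conditional DCGAN generator.

    Expects noise of shape (batch, latent_dim, 1, 1) and integer class
    labels of shape (batch,). The label embedding is concatenated with
    the noise channel-wise and upsampled to an (nc, 28, 28) image.
    """
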
    def __init__(self, config):
        super().__init__()

        self.latent_dim = config["latent_dim"]
        self.ngf = config["ngf"]
        self.nc = config["nc"]

        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # Label embedding: maps labels to vectors of size embed_dim.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        # DCGAN generator architecture
        self.main = nn.Sequential(
            # Input: noise + label embedding -> (latent_dim + embed_dim) x 1 x 1.
            # Upscale to 7x7 with ngf*4 channels.
            nn.ConvTranspose2d(self.latent_dim + self.embed_dim, self.ngf * 4,
                               kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),

            # 7x7 -> 14x14
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),

            # 14x14 -> 28x28.
            nn.ConvTranspose2d(self.ngf * 2, self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),

            # Final layer: convert to nc channels, preserving 28x28.
            nn.ConvTranspose2d(self.ngf, self.nc, kernel_size=3,
                               stride=1, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Embed labels and reshape to (batch, embed_dim, 1, 1)
        label_embedding = self.label_embed(labels).unsqueeze(2).unsqueeze(3)
        # Concatenate noise and embedded labels along the channel dimension.
        gen_input = torch.cat([noise, label_embedding], dim=1)
        return self.main(gen_input)


class Discriminator(nn.Module):
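    """Conditional DCGAN discriminator.

    Expects an image of shape (batch, nc, 28, 28) and integer class labels
    of shape (batch,). The label embedding is broadcast across the spatial
    dimensions, concatenated with the image channel-wise, and reduced to a
    single real/fake probability per image.
    """
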
    def __init__(self, config):
        super().__init__()

        self.ndf = config["ndf"]
        self.nc = config["nc"]

        self.n_classes = config["num_classes"]
        self.embed_dim = config["embed_dim"]

        # Label embedding: maps labels to vectors of size embed_dim.
        self.label_embed = nn.Embedding(self.n_classes, self.embed_dim)

        # DCGAN discriminator architecture
        self.main = nn.Sequential(

            # Input: (nc + embed_dim) x 28 x 28
            nn.Conv2d(self.nc + self.embed_dim, self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # State: (ndf) x 14 x 14
            nn.Conv2d(self.ndf, self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # State: (ndf*2) x 7 x 7 -> collapse to 1 x 1 with a 7x7 kernel.
            nn.Conv2d(self.ndf * 2, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # Embed the labels and replicate them spatially
        label_embedding = self.label_embed(labels).unsqueeze(2).unsqueeze(3)
        # Assume img is of shape (batch, nc, H, W)
        label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
        
        # Concatenate the image with the label embedding
        d_in = torch.cat((img, label_embedding), dim=1)

        return self.main(d_in).view(-1, 1).squeeze(1)
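

# A minimal smoke-test sketch. The config keys match what the constructors
# above read, but the values (latent_dim=100, ngf/ndf=64, nc=1,
# num_classes=10, embed_dim=10) are illustrative assumptions, not values
# taken from an actual training config.
if __name__ == "__main__":
    config = {
        "latent_dim": 100,
        "ngf": 64,
        "ndf": 64,
        "nc": 1,
        "num_classes": 10,
        "embed_dim": 10,
    }
    netG = Generator(config)
    netD = Discriminator(config)

    batch_size = 4
    noise = torch.randn(batch_size, config["latent_dim"], 1, 1)
    labels = torch.randint(0, config["num_classes"], (batch_size,))

    fake = netG(noise, labels)     # expected shape: (4, 1, 28, 28)
    scores = netD(fake, labels)    # expected shape: (4,)
    print(fake.shape, scores.shape)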