bryandts commited on
Commit
e05c407
·
verified ·
1 Parent(s): 8610a1c

Update discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +1 -16
discriminatorModel.py CHANGED
@@ -1,22 +1,7 @@
1
 
2
  import torch
3
  import torch.nn as nn
4
-
5
- # The Embedding model
6
- class Embedding(nn.Module):
7
- def __init__(self, size_in, size_out):
8
- super(Embedding, self).__init__()
9
- self.text_embedding = nn.Sequential(
10
- nn.Linear(size_in, size_out),
11
- nn.BatchNorm1d(1),
12
- nn.LeakyReLU(0.2, inplace=True)
13
- )
14
-
15
- def forward(self, x, text):
16
- embed_out = self.text_embedding(text)
17
- embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
18
- out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
19
- return out
20
 
21
  # The Discriminator model
22
  class Discriminator(nn.Module):
 
1
 
2
  import torch
3
  import torch.nn as nn
4
+ from discriminatorEmbedding import Embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # The Discriminator model
7
  class Discriminator(nn.Module):