Luisgust commited on
Commit
177d4e8
·
verified ·
1 Parent(s): ab4d201

Create vtoonify/model/encoder/encoders/psp_encoders.py

Browse files
vtoonify/model/encoder/encoders/psp_encoders.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from model.encoder.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
9
+ from model.stylegan.model import EqualLinear
10
+
11
+
12
+ class GradualStyleBlock(Module):
13
+ def __init__(self, in_c, out_c, spatial):
14
+ super(GradualStyleBlock, self).__init__()
15
+ self.out_c = out_c
16
+ self.spatial = spatial
17
+ num_pools = int(np.log2(spatial))
18
+ modules = []
19
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
20
+ nn.LeakyReLU()]
21
+ for i in range(num_pools - 1):
22
+ modules += [
23
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
24
+ nn.LeakyReLU()
25
+ ]
26
+ self.convs = nn.Sequential(*modules)
27
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
28
+
29
+ def forward(self, x):
30
+ x = self.convs(x)
31
+ x = x.view(-1, self.out_c)
32
+ x = self.linear(x)
33
+ return x
34
+
35
+
36
+ class GradualStyleEncoder(Module):
37
+ def __init__(self, num_layers, mode='ir', opts=None):
38
+ super(GradualStyleEncoder, self).__init__()
39
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
40
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
41
+ blocks = get_blocks(num_layers)
42
+ if mode == 'ir':
43
+ unit_module = bottleneck_IR
44
+ elif mode == 'ir_se':
45
+ unit_module = bottleneck_IR_SE
46
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
47
+ BatchNorm2d(64),
48
+ PReLU(64))
49
+ modules = []
50
+ for block in blocks:
51
+ for bottleneck in block:
52
+ modules.append(unit_module(bottleneck.in_channel,
53
+ bottleneck.depth,
54
+ bottleneck.stride))
55
+ self.body = Sequential(*modules)
56
+
57
+ self.styles = nn.ModuleList()
58
+ self.style_count = opts.n_styles
59
+ self.coarse_ind = 3
60
+ self.middle_ind = 7
61
+ for i in range(self.style_count):
62
+ if i < self.coarse_ind:
63
+ style = GradualStyleBlock(512, 512, 16)
64
+ elif i < self.middle_ind:
65
+ style = GradualStyleBlock(512, 512, 32)
66
+ else:
67
+ style = GradualStyleBlock(512, 512, 64)
68
+ self.styles.append(style)
69
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
70
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
71
+
72
+ def _upsample_add(self, x, y):
73
+ '''Upsample and add two feature maps.
74
+ Args:
75
+ x: (Variable) top feature map to be upsampled.
76
+ y: (Variable) lateral feature map.
77
+ Returns:
78
+ (Variable) added feature map.
79
+ Note in PyTorch, when input size is odd, the upsampled feature map
80
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
81
+ maybe not equal to the lateral feature map size.
82
+ e.g.
83
+ original input size: [N,_,15,15] ->
84
+ conv2d feature map size: [N,_,8,8] ->
85
+ upsampled feature map size: [N,_,16,16]
86
+ So we choose bilinear upsample which supports arbitrary output sizes.
87
+ '''
88
+ _, _, H, W = y.size()
89
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
90
+
91
+ def forward(self, x):
92
+ x = self.input_layer(x)
93
+
94
+ latents = []
95
+ modulelist = list(self.body._modules.values())
96
+ for i, l in enumerate(modulelist):
97
+ x = l(x)
98
+ if i == 6:
99
+ c1 = x
100
+ elif i == 20:
101
+ c2 = x
102
+ elif i == 23:
103
+ c3 = x
104
+
105
+ for j in range(self.coarse_ind):
106
+ latents.append(self.styles[j](c3))
107
+
108
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
109
+ for j in range(self.coarse_ind, self.middle_ind):
110
+ latents.append(self.styles[j](p2))
111
+
112
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
113
+ for j in range(self.middle_ind, self.style_count):
114
+ latents.append(self.styles[j](p1))
115
+
116
+ out = torch.stack(latents, dim=1)
117
+ return out
118
+
119
+
120
+ class BackboneEncoderUsingLastLayerIntoW(Module):
121
+ def __init__(self, num_layers, mode='ir', opts=None):
122
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
123
+ print('Using BackboneEncoderUsingLastLayerIntoW')
124
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
125
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
126
+ blocks = get_blocks(num_layers)
127
+ if mode == 'ir':
128
+ unit_module = bottleneck_IR
129
+ elif mode == 'ir_se':
130
+ unit_module = bottleneck_IR_SE
131
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
132
+ BatchNorm2d(64),
133
+ PReLU(64))
134
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
135
+ self.linear = EqualLinear(512, 512, lr_mul=1)
136
+ modules = []
137
+ for block in blocks:
138
+ for bottleneck in block:
139
+ modules.append(unit_module(bottleneck.in_channel,
140
+ bottleneck.depth,
141
+ bottleneck.stride))
142
+ self.body = Sequential(*modules)
143
+
144
+ def forward(self, x):
145
+ x = self.input_layer(x)
146
+ x = self.body(x)
147
+ x = self.output_pool(x)
148
+ x = x.view(-1, 512)
149
+ x = self.linear(x)
150
+ return x
151
+
152
+
153
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
154
+ def __init__(self, num_layers, mode='ir', opts=None):
155
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
156
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
157
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
158
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
159
+ blocks = get_blocks(num_layers)
160
+ if mode == 'ir':
161
+ unit_module = bottleneck_IR
162
+ elif mode == 'ir_se':
163
+ unit_module = bottleneck_IR_SE
164
+ self.n_styles = opts.n_styles
165
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
166
+ BatchNorm2d(64),
167
+ PReLU(64))
168
+ self.output_layer_2 = Sequential(BatchNorm2d(512),
169
+ torch.nn.AdaptiveAvgPool2d((7, 7)),
170
+ Flatten(),
171
+ Linear(512 * 7 * 7, 512))
172
+ self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
173
+ modules = []
174
+ for block in blocks:
175
+ for bottleneck in block:
176
+ modules.append(unit_module(bottleneck.in_channel,
177
+ bottleneck.depth,
178
+ bottleneck.stride))
179
+ self.body = Sequential(*modules)
180
+
181
+ def forward(self, x):
182
+ x = self.input_layer(x)
183
+ x = self.body(x)
184
+ x = self.output_layer_2(x)
185
+ x = self.linear(x)
186
+ x = x.view(-1, self.n_styles, 512)
187
+ return x
188
+