Luisgust committed on
Commit
6bd676b
·
verified ·
1 Parent(s): d7439a9

Create vtoonify/model/bisenet/model.py

Browse files
Files changed (1) hide show
  1. vtoonify/model/bisenet/model.py +285 -0
vtoonify/model/bisenet/model.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #!/usr/bin/python
3
+ # -*- encoding: utf-8 -*-
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+
11
+ from model.bisenet.resnet import Resnet18
12
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
13
+
14
+
15
class ConvBNReLU(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d`` and ReLU.

    Args:
        in_chan: number of input channels of the convolution.
        out_chan: number of output channels (also the BatchNorm2d width).
        ks: convolution kernel size.
        stride: convolution stride.
        padding: padding passed to the convolution.
        *args, **kwargs: accepted so existing call sites keep working;
            they are not used.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        x = self.conv(x)
        return F.relu(self.bn(x))

    def init_weight(self):
        """Re-initialise direct ``nn.Conv2d`` children.

        Weights get ``kaiming_normal_`` with ``a=1``; a bias, when the layer
        has one, is set to zero. Only direct children are visited.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
37
+
38
class BiSeNetOutput(nn.Module):
    """Output head: a 3x3 ``ConvBNReLU`` followed by a bias-free 1x1 convolution.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels produced by the intermediate ``ConvBNReLU``.
        n_classes: channels produced by the final 1x1 convolution.
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return ``conv_out(conv(x))``; no activation is applied afterwards."""
        x = self.conv(x)
        return self.conv_out(x)

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited, so here this touches ``conv_out``;
        ``self.conv`` is a ``ConvBNReLU`` and initialises itself in its own
        constructor.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns:
            ``(wd_params, nowd_params)``: the first holds the weights of every
            ``nn.Linear`` / ``nn.Conv2d`` found in this module tree; the second
            holds their biases (when present) and all ``nn.BatchNorm2d``
            parameters. The names suggest "weight decay" / "no weight decay"
            groups for an optimizer -- confirm against the training script.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
66
+
67
+
68
class AttentionRefinementModule(nn.Module):
    """Re-weight a feature map per channel with a globally pooled gate.

    The input goes through a 3x3 ``ConvBNReLU``; the result is average-pooled
    over its full spatial extent, passed through a bias-free 1x1 convolution,
    BatchNorm2d and a sigmoid, and multiplied back onto the feature map.

    Args:
        in_chan: channels of the incoming feature map.
        out_chan: channels of the refined feature map.
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return ``feat * gate`` where ``gate`` is broadcast over the spatial dims."""
        feat = self.conv(x)
        # Pooling kernel equals the spatial size, so each channel collapses to 1x1.
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        return torch.mul(feat, atten)

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited, i.e. ``conv_atten`` here.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
91
+
92
+
93
class ContextPath(nn.Module):
    """Backbone plus attention-refined, top-down merged features.

    ``Resnet18`` is expected to return three feature maps, unpacked here as
    ``feat8``, ``feat16`` and ``feat32`` (the names suggest 1/8, 1/16 and 1/32
    of the input resolution -- confirm against ``Resnet18``). ``arm16`` takes
    256 input channels, while ``arm32`` and ``conv_avg`` take 512.

    Args:
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Run the backbone and merge its two deepest features top-down.

        Returns:
            ``(feat8, feat16_up, feat32_up)``: the untouched ``feat8`` from the
            backbone, the merged feature resized to ``feat8``'s spatial size,
            and the merged feature resized to ``feat16``'s spatial size.
        """
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: pool feat32 to 1x1, project it, then tile it back.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited. The ``AttentionRefinementModule``
        and ``ConvBNReLU`` children are not ``nn.Conv2d`` instances, so they
        are skipped here; each initialises itself in its own constructor.
        Presumably ``Resnet18`` is not an ``nn.Conv2d`` either, which would
        leave the backbone weights untouched -- confirm.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns:
            ``(wd_params, nowd_params)``: the first holds the weights of every
            ``nn.Linear`` / ``nn.Conv2d`` in this module tree (backbone
            included); the second holds their biases (when present) and all
            ``nn.BatchNorm2d`` parameters.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
144
+
145
+
146
# NOTE: SpatialPath is not used by BiSeNet in this file; BiSeNet.forward feeds
# the backbone feature returned by ContextPath (feat_res8) to the fusion
# module in its place.
class SpatialPath(nn.Module):
    """Stack of four ``ConvBNReLU`` layers mapping a 3-channel input to 128 channels.

    The first three layers each use ``stride=2``; the last is a 1x1 projection
    from 64 to 128 channels.

    Args:
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Apply ``conv1``, ``conv2``, ``conv3`` and ``conv_out`` in sequence."""
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        return self.conv_out(feat)

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited; the ``ConvBNReLU`` children are not
        ``nn.Conv2d`` instances and initialise themselves in their own
        constructors.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns:
            ``(wd_params, nowd_params)``: the first holds the weights of every
            ``nn.Linear`` / ``nn.Conv2d`` in this module tree; the second holds
            their biases (when present) and all ``nn.BatchNorm2d`` parameters.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
179
+
180
+
181
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps by concatenation plus a channel-wise gate.

    The inputs are concatenated along the channel axis and projected by a 1x1
    ``ConvBNReLU``. A gate is computed from the spatially averaged result
    (1x1 conv to ``out_chan // 4`` -> ReLU -> 1x1 conv back to ``out_chan`` ->
    sigmoid), and the output is ``feat * gate + feat``.

    Args:
        in_chan: total channels after concatenating both inputs.
        out_chan: channels of the fused output.
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(
            out_chan,
            out_chan // 4,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            out_chan // 4,
            out_chan,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse ``fsp`` and ``fcp``.

        Both inputs are concatenated on ``dim=1``, so they must agree on every
        other dimension.
        """
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        # Pooling kernel equals the spatial size, so each channel collapses to 1x1.
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        return feat_atten + feat

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited, i.e. ``conv1`` and ``conv2`` here.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into two lists.

        Returns:
            ``(wd_params, nowd_params)``: the first holds the weights of every
            ``nn.Linear`` / ``nn.Conv2d`` in this module tree; the second holds
            their biases (when present) and all ``nn.BatchNorm2d`` parameters.
        """
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
229
+
230
+
231
class BiSeNet(nn.Module):
    """Segmentation network built from ``ContextPath``, a fusion module and three heads.

    No separate spatial path is instantiated: the first feature returned by
    ``ContextPath`` is used in its place as the fusion module's first input.

    Args:
        n_classes: number of output channels of each head.
        *args, **kwargs: accepted for call-site compatibility; not used.
    """

    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        """Return three ``n_classes``-channel maps, each resized to the input's (H, W).

        Returns:
            ``(feat_out, feat_out16, feat_out32)``: the head on the fused
            feature, and the two heads on the context-path features. All three
            are bilinearly resized with ``align_corners=True``.
        """
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        """Kaiming-normal (``a=1``) init for direct ``nn.Conv2d`` children.

        Only direct children are visited. The children assigned in
        ``__init__`` are composite modules rather than ``nn.Conv2d``
        instances, so as constructed this loop matches nothing; each
        sub-module initialises itself in its own constructor.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Collect the children's parameter groups into four lists.

        Every direct child must provide ``get_params()`` returning
        ``(wd_params, nowd_params)``.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Groups from ``FeatureFusionModule`` and ``BiSeNetOutput`` children
            go into the two ``lr_mul_*`` lists; groups from every other child
            go into the first two. The ``lr_mul`` prefix suggests these get a
            learning-rate multiplier in the optimizer -- confirm against the
            training script.
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for child in self.children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
274
+
275
+
276
if __name__ == "__main__":
    # Smoke test: build the network, push one random batch through it and
    # print the shape of the main output. Uses the GPU when one is available
    # and falls back to the CPU otherwise, instead of failing on `.cuda()`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BiSeNet(19)
    net.to(device)
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).to(device)
    # Inference only: no autograd graph is needed for a shape check.
    with torch.no_grad():
        out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()
285
+