Luisgust commited on
Commit
9895fdb
·
verified ·
1 Parent(s): 1b8f176

Create vtoonify/model/stylegan/lpips/networks_basic.py

Browse files
vtoonify/model/stylegan/lpips/networks_basic.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ import sys
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.init as init
7
+ from torch.autograd import Variable
8
+ import numpy as np
9
+ from pdb import set_trace as st
10
+ from skimage import color
11
+ from IPython import embed
12
+ from model.stylegan.lpips import pretrained_networks as pn
13
+
14
+ import model.stylegan.lpips as util
15
+
16
+ def spatial_average(in_tens, keepdim=True):
17
+ return in_tens.mean([2,3],keepdim=keepdim)
18
+
19
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
20
+ in_H = in_tens.shape[2]
21
+ scale_factor = 1.*out_H/in_H
22
+
23
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
24
+
25
+ # Learned perceptual metric
26
+ class PNetLin(nn.Module):
27
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
28
+ super(PNetLin, self).__init__()
29
+
30
+ self.pnet_type = pnet_type
31
+ self.pnet_tune = pnet_tune
32
+ self.pnet_rand = pnet_rand
33
+ self.spatial = spatial
34
+ self.lpips = lpips
35
+ self.version = version
36
+ self.scaling_layer = ScalingLayer()
37
+
38
+ if(self.pnet_type in ['vgg','vgg16']):
39
+ net_type = pn.vgg16
40
+ self.chns = [64,128,256,512,512]
41
+ elif(self.pnet_type=='alex'):
42
+ net_type = pn.alexnet
43
+ self.chns = [64,192,384,256,256]
44
+ elif(self.pnet_type=='squeeze'):
45
+ net_type = pn.squeezenet
46
+ self.chns = [64,128,256,384,384,512,512]
47
+ self.L = len(self.chns)
48
+
49
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
50
+
51
+ if(lpips):
52
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
53
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
54
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
55
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
56
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
57
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
58
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
59
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
60
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
61
+ self.lins+=[self.lin5,self.lin6]
62
+
63
+ def forward(self, in0, in1, retPerLayer=False):
64
+ # v0.0 - original release had a bug, where input was not scaled
65
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
66
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
67
+ feats0, feats1, diffs = {}, {}, {}
68
+
69
+ for kk in range(self.L):
70
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
71
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
72
+
73
+ if(self.lpips):
74
+ if(self.spatial):
75
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
76
+ else:
77
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
78
+ else:
79
+ if(self.spatial):
80
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
81
+ else:
82
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
83
+
84
+ val = res[0]
85
+ for l in range(1,self.L):
86
+ val += res[l]
87
+
88
+ if(retPerLayer):
89
+ return (val, res)
90
+ else:
91
+ return val
92
+
93
+ class ScalingLayer(nn.Module):
94
+ def __init__(self):
95
+ super(ScalingLayer, self).__init__()
96
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
97
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
98
+
99
+ def forward(self, inp):
100
+ return (inp - self.shift) / self.scale
101
+
102
+
103
+ class NetLinLayer(nn.Module):
104
+ ''' A single linear layer which does a 1x1 conv '''
105
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
106
+ super(NetLinLayer, self).__init__()
107
+
108
+ layers = [nn.Dropout(),] if(use_dropout) else []
109
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
110
+ self.model = nn.Sequential(*layers)
111
+
112
+
113
+ class Dist2LogitLayer(nn.Module):
114
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
115
+ def __init__(self, chn_mid=32, use_sigmoid=True):
116
+ super(Dist2LogitLayer, self).__init__()
117
+
118
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
119
+ layers += [nn.LeakyReLU(0.2,True),]
120
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
121
+ layers += [nn.LeakyReLU(0.2,True),]
122
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
123
+ if(use_sigmoid):
124
+ layers += [nn.Sigmoid(),]
125
+ self.model = nn.Sequential(*layers)
126
+
127
+ def forward(self,d0,d1,eps=0.1):
128
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
129
+
130
+ class BCERankingLoss(nn.Module):
131
+ def __init__(self, chn_mid=32):
132
+ super(BCERankingLoss, self).__init__()
133
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
134
+ # self.parameters = list(self.net.parameters())
135
+ self.loss = torch.nn.BCELoss()
136
+
137
+ def forward(self, d0, d1, judge):
138
+ per = (judge+1.)/2.
139
+ self.logit = self.net.forward(d0,d1)
140
+ return self.loss(self.logit, per)
141
+
142
+ # L2, DSSIM metrics
143
+ class FakeNet(nn.Module):
144
+ def __init__(self, use_gpu=True, colorspace='Lab'):
145
+ super(FakeNet, self).__init__()
146
+ self.use_gpu = use_gpu
147
+ self.colorspace=colorspace
148
+
149
+ class L2(FakeNet):
150
+
151
+ def forward(self, in0, in1, retPerLayer=None):
152
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
153
+
154
+ if(self.colorspace=='RGB'):
155
+ (N,C,X,Y) = in0.size()
156
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
157
+ return value
158
+ elif(self.colorspace=='Lab'):
159
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
160
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
161
+ ret_var = Variable( torch.Tensor((value,) ) )
162
+ if(self.use_gpu):
163
+ ret_var = ret_var.cuda()
164
+ return ret_var
165
+
166
+ class DSSIM(FakeNet):
167
+
168
+ def forward(self, in0, in1, retPerLayer=None):
169
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
170
+
171
+ if(self.colorspace=='RGB'):
172
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
173
+ elif(self.colorspace=='Lab'):
174
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
175
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
176
+ ret_var = Variable( torch.Tensor((value,) ) )
177
+ if(self.use_gpu):
178
+ ret_var = ret_var.cuda()
179
+ return ret_var
180
+
181
+ def print_network(net):
182
+ num_params = 0
183
+ for param in net.parameters():
184
+ num_params += param.numel()
185
+ print('Network',net)
186
+ print('Total number of parameters: %d' % num_params)