Luisgust commited on
Commit
1b8f176
·
verified ·
1 Parent(s): 7afa197

Create vtoonify/model/stylegan/lpips/dist_model.py

Browse files
vtoonify/model/stylegan/lpips/dist_model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ import sys
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import os
8
+ from collections import OrderedDict
9
+ from torch.autograd import Variable
10
+ import itertools
11
+ from model.stylegan.lpips.base_model import BaseModel
12
+ from scipy.ndimage import zoom
13
+ import fractions
14
+ import functools
15
+ import skimage.transform
16
+ from tqdm import tqdm
17
+
18
+ from IPython import embed
19
+
20
+ from model.stylegan.lpips import networks_basic as networks
21
+ import model.stylegan.lpips as util
22
+
23
+ class DistModel(BaseModel):
24
+ def name(self):
25
+ return self.model_name
26
+
27
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
28
+ use_gpu=True, printNet=False, spatial=False,
29
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
30
+ '''
31
+ INPUTS
32
+ model - ['net-lin'] for linearly calibrated network
33
+ ['net'] for off-the-shelf network
34
+ ['L2'] for L2 distance in Lab colorspace
35
+ ['SSIM'] for ssim in RGB colorspace
36
+ net - ['squeeze','alex','vgg']
37
+ model_path - if None, will look in weights/[NET_NAME].pth
38
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
39
+ use_gpu - bool - whether or not to use a GPU
40
+ printNet - bool - whether or not to print network architecture out
41
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
42
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
43
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
44
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
45
+ is_train - bool - [True] for training mode
46
+ lr - float - initial learning rate
47
+ beta1 - float - initial momentum term for adam
48
+ version - 0.1 for latest, 0.0 was original (with a bug)
49
+ gpu_ids - int array - [0] by default, gpus to use
50
+ '''
51
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
52
+
53
+ self.model = model
54
+ self.net = net
55
+ self.is_train = is_train
56
+ self.spatial = spatial
57
+ self.gpu_ids = gpu_ids
58
+ self.model_name = '%s [%s]'%(model,net)
59
+
60
+ if(self.model == 'net-lin'): # pretrained net + linear layer
61
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
62
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
63
+ kw = {}
64
+ if not use_gpu:
65
+ kw['map_location'] = 'cpu'
66
+ if(model_path is None):
67
+ import inspect
68
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
69
+
70
+ if(not is_train):
71
+ print('Loading model from: %s'%model_path)
72
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
73
+
74
+ elif(self.model=='net'): # pretrained network
75
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
76
+ elif(self.model in ['L2','l2']):
77
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
78
+ self.model_name = 'L2'
79
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
80
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
81
+ self.model_name = 'SSIM'
82
+ else:
83
+ raise ValueError("Model [%s] not recognized." % self.model)
84
+
85
+ self.parameters = list(self.net.parameters())
86
+
87
+ if self.is_train: # training mode
88
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
89
+ self.rankLoss = networks.BCERankingLoss()
90
+ self.parameters += list(self.rankLoss.net.parameters())
91
+ self.lr = lr
92
+ self.old_lr = lr
93
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
94
+ else: # test mode
95
+ self.net.eval()
96
+
97
+ if(use_gpu):
98
+ self.net.to(gpu_ids[0])
99
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
100
+ if(self.is_train):
101
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
102
+
103
+ if(printNet):
104
+ print('---------- Networks initialized -------------')
105
+ networks.print_network(self.net)
106
+ print('-----------------------------------------------')
107
+
108
+ def forward(self, in0, in1, retPerLayer=False):
109
+ ''' Function computes the distance between image patches in0 and in1
110
+ INPUTS
111
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
112
+ OUTPUT
113
+ computed distances between in0 and in1
114
+ '''
115
+
116
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
117
+
118
+ # ***** TRAINING FUNCTIONS *****
119
+ def optimize_parameters(self):
120
+ self.forward_train()
121
+ self.optimizer_net.zero_grad()
122
+ self.backward_train()
123
+ self.optimizer_net.step()
124
+ self.clamp_weights()
125
+
126
+ def clamp_weights(self):
127
+ for module in self.net.modules():
128
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
129
+ module.weight.data = torch.clamp(module.weight.data,min=0)
130
+
131
+ def set_input(self, data):
132
+ self.input_ref = data['ref']
133
+ self.input_p0 = data['p0']
134
+ self.input_p1 = data['p1']
135
+ self.input_judge = data['judge']
136
+
137
+ if(self.use_gpu):
138
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
139
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
140
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
141
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
142
+
143
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
144
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
145
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
146
+
147
+ def forward_train(self): # run forward pass
148
+ # print(self.net.module.scaling_layer.shift)
149
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
150
+
151
+ self.d0 = self.forward(self.var_ref, self.var_p0)
152
+ self.d1 = self.forward(self.var_ref, self.var_p1)
153
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
154
+
155
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
156
+
157
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
158
+
159
+ return self.loss_total
160
+
161
+ def backward_train(self):
162
+ torch.mean(self.loss_total).backward()
163
+
164
+ def compute_accuracy(self,d0,d1,judge):
165
+ ''' d0, d1 are Variables, judge is a Tensor '''
166
+ d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
167
+ judge_per = judge.cpu().numpy().flatten()
168
+ return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
169
+
170
+ def get_current_errors(self):
171
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
172
+ ('acc_r', self.acc_r)])
173
+
174
+ for key in retDict.keys():
175
+ retDict[key] = np.mean(retDict[key])
176
+
177
+ return retDict
178
+
179
+ def get_current_visuals(self):
180
+ zoom_factor = 256/self.var_ref.data.size()[2]
181
+
182
+ ref_img = util.tensor2im(self.var_ref.data)
183
+ p0_img = util.tensor2im(self.var_p0.data)
184
+ p1_img = util.tensor2im(self.var_p1.data)
185
+
186
+ ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
187
+ p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
188
+ p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
189
+
190
+ return OrderedDict([('ref', ref_img_vis),
191
+ ('p0', p0_img_vis),
192
+ ('p1', p1_img_vis)])
193
+
194
+ def save(self, path, label):
195
+ if(self.use_gpu):
196
+ self.save_network(self.net.module, path, '', label)
197
+ else:
198
+ self.save_network(self.net, path, '', label)
199
+ self.save_network(self.rankLoss.net, path, 'rank', label)
200
+
201
+ def update_learning_rate(self,nepoch_decay):
202
+ lrd = self.lr / nepoch_decay
203
+ lr = self.old_lr - lrd
204
+
205
+ for param_group in self.optimizer_net.param_groups:
206
+ param_group['lr'] = lr
207
+
208
+ print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
209
+ self.old_lr = lr
210
+
211
+ def score_2afc_dataset(data_loader, func, name=''):
212
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
213
+ distance function 'func' in dataset 'data_loader'
214
+ INPUTS
215
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
216
+ func - callable distance function - calling d=func(in0,in1) should take 2
217
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
218
+ OUTPUTS
219
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
220
+ [1] - dictionary with following elements
221
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
222
+ gts - N array in [0,1], preferred patch selected by human evaluators
223
+ (closer to "0" for left patch p0, "1" for right patch p1,
224
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
225
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
226
+ CONSTS
227
+ N - number of test triplets in data_loader
228
+ '''
229
+
230
+ d0s = []
231
+ d1s = []
232
+ gts = []
233
+
234
+ for data in tqdm(data_loader.load_data(), desc=name):
235
+ d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
236
+ d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
237
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
238
+
239
+ d0s = np.array(d0s)
240
+ d1s = np.array(d1s)
241
+ gts = np.array(gts)
242
+ scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
243
+
244
+ return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
245
+
246
+ def score_jnd_dataset(data_loader, func, name=''):
247
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
248
+ INPUTS
249
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
250
+ func - callable distance function - calling d=func(in0,in1) should take 2
251
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
252
+ OUTPUTS
253
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
254
+ [1] - dictionary with following elements
255
+ ds - N array containing distances between two patches shown to human evaluator
256
+ sames - N array containing fraction of people who thought the two patches were identical
257
+ CONSTS
258
+ N - number of test triplets in data_loader
259
+ '''
260
+
261
+ ds = []
262
+ gts = []
263
+
264
+ for data in tqdm(data_loader.load_data(), desc=name):
265
+ ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
266
+ gts+=data['same'].cpu().numpy().flatten().tolist()
267
+
268
+ sames = np.array(gts)
269
+ ds = np.array(ds)
270
+
271
+ sorted_inds = np.argsort(ds)
272
+ ds_sorted = ds[sorted_inds]
273
+ sames_sorted = sames[sorted_inds]
274
+
275
+ TPs = np.cumsum(sames_sorted)
276
+ FPs = np.cumsum(1-sames_sorted)
277
+ FNs = np.sum(sames_sorted)-TPs
278
+
279
+ precs = TPs/(TPs+FPs)
280
+ recs = TPs/(TPs+FNs)
281
+ score = util.voc_ap(recs,precs)
282
+
283
+ return(score, dict(ds=ds,sames=sames))