Create vtoonify/model/stylegan/lpips/dist_model.py
vtoonify/model/stylegan/lpips/dist_model.py
ADDED
@@ -0,0 +1,283 @@
from __future__ import absolute_import

import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from model.stylegan.lpips.base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm

from IPython import embed

from model.stylegan.lpips import networks_basic as networks
import model.stylegan.lpips as util

class DistModel(BaseModel):
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
                   use_gpu=True, printNet=False, spatial=False,
                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - [0] by default, gpus to use
        '''
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]' % (model, net)

        if(self.model == 'net-lin'):  # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                import inspect
                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth' % (version, net)))

            if(not is_train):
                print('Loading model from: %s' % model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model == 'net'):  # pretrained network
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2', 'l2']):
            self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace)  # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:  # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else:  # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])  # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
        OUTPUT
            computed distances between in0 and in1
        '''

        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

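    # A minimal call sketch for forward (hypothetical inputs, not part of the
    # commit), assuming img0/img1 are [-1,1]-scaled tensors of shape Nx3xXxY:
    #   d = model.forward(img0, img1)
    #   d, per_layer = model.forward(img0, img1, retPerLayer=True)
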
    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
                module.weight.data = torch.clamp(module.weight.data, min=0)

    def set_input(self, data):
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref, requires_grad=True)
        self.var_p0 = Variable(self.input_p0, requires_grad=True)
        self.var_p1 = Variable(self.input_p1, requires_grad=True)

    def forward_train(self):  # run forward pass
        # print(self.net.module.scaling_layer.shift)
        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())

        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)

        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                               ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        zoom_factor = 256 / self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self, nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        print('update lr decay: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

+
def score_2afc_dataset(data_loader, func, name=''):
|
212 |
+
''' Function computes Two Alternative Forced Choice (2AFC) score using
|
213 |
+
distance function 'func' in dataset 'data_loader'
|
214 |
+
INPUTS
|
215 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
216 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
217 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
218 |
+
OUTPUTS
|
219 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
220 |
+
[1] - dictionary with following elements
|
221 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
222 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
223 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
224 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
225 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
226 |
+
CONSTS
|
227 |
+
N - number of test triplets in data_loader
|
228 |
+
'''
|
229 |
+
|
230 |
+
d0s = []
|
231 |
+
d1s = []
|
232 |
+
gts = []
|
233 |
+
|
234 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
235 |
+
d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
|
236 |
+
d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
|
237 |
+
gts+=data['judge'].cpu().numpy().flatten().tolist()
|
238 |
+
|
239 |
+
d0s = np.array(d0s)
|
240 |
+
d1s = np.array(d1s)
|
241 |
+
gts = np.array(gts)
|
242 |
+
scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
|
243 |
+
|
244 |
+
return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
|
245 |
+
|
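# Worked example for the 2AFC score above (illustrative numbers, not from any
# dataset): with d0s=[0.2, 0.7], d1s=[0.5, 0.4], gts=[0.0, 1.0], the first
# triplet scores (d0s<d1s)*(1-gts) = 1.0 (the metric and all humans picked p0)
# and the second scores (d1s<d0s)*gts = 1.0, so the mean 2AFC score is 1.0;
# exact ties in distance contribute 0.5.
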
def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return a pytorch tensor of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test pairs in data_loader
    '''

    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
        gts += data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    # sweep a threshold over the sorted distances: everything below counts as
    # "same", giving cumulative true/false positives and the PR curve
    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs

    precs = TPs / (TPs + FPs)
    recs = TPs / (TPs + FNs)
    score = util.voc_ap(recs, precs)

    return score, dict(ds=ds, sames=sames)
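
For orientation, a minimal usage sketch (not part of the commit). It assumes the vtoonify root is on the import path and that the pretrained linear-calibration weights resolve at weights/v0.1/alex.pth next to this module, which is the default path computed in initialize; set use_gpu=True instead if a GPU is available.

# Hypothetical usage sketch; assumes weights/v0.1/alex.pth exists
import torch
from model.stylegan.lpips.dist_model import DistModel

model = DistModel()
model.initialize(model='net-lin', net='alex', use_gpu=False)  # LPIPS on CPU

# two random patches scaled to [-1,1], shape Nx3xXxY
img0 = torch.rand(1, 3, 64, 64) * 2. - 1.
img1 = torch.rand(1, 3, 64, 64) * 2. - 1.

with torch.no_grad():
    dist = model.forward(img0, img1)  # one perceptual distance per pair
print(float(dist.mean()))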