import numpy as np
from sklearn.metrics import average_precision_score
import torch
import torch.nn.functional as F
import torch.utils.data
def to_numpy(array):
    """Convert a torch Tensor (or legacy Variable) to a numpy array."""
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, torch.autograd.Variable):
        array = array.data
    if array.is_cuda:
        array = array.cpu()
    # detach() guards against tensors that still require grad
    return array.detach().numpy()
def squeeze(array):
    """Unwrap single-element lists; leave everything else unchanged."""
    if not isinstance(array, list) or len(array) != 1:
        return array
    return array[0]
def unsqueeze(array):
    """Wrap non-list values in a list; leave lists unchanged."""
    if isinstance(array, list):
        return array
    return [array]
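# Illustrative examples (assumed typical usage, e.g. normalizing model
# outputs that may or may not be wrapped in a list):
#   squeeze([x])      -> x
#   squeeze([x, y])   -> [x, y]
#   unsqueeze(x)      -> [x]
#   unsqueeze([x, y]) -> [x, y]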
def is_due(*args):
    """Determines whether to perform an action, depending on the epoch or step.
    Used for logging, saving, learning rate decay, etc.
    Args:
        *args: one of
            epoch, due_at: due at the epochs listed in due_at
            epoch, num_epochs, due_every: due every due_every epochs,
                and at the final epoch
            step, due_every: due every due_every steps
    Returns:
        due: boolean: perform action or not
    """
    if len(args) == 2 and isinstance(args[1], list):
        epoch, due_at = args
        due = epoch in due_at
    elif len(args) == 3:
        epoch, num_epochs, due_every = args
        # due_every must be positive, otherwise epoch % due_every would
        # raise ZeroDivisionError
        due = (due_every > 0) and (epoch % due_every == 0 or epoch == num_epochs)
    else:
        step, due_every = args
        due = (due_every > 0) and (step % due_every == 0)
    return due
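# Examples:
#   is_due(10, [5, 10])  -> True   (epoch 10 is in the due_at list)
#   is_due(3, 100, 2)    -> False  (epoch 3 is not a multiple of 2)
#   is_due(100, 100, 3)  -> True   (the final epoch is always due)
#   is_due(200, 50)      -> True   (step 200 is a multiple of 50)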
def softmax(w, t=1.0, axis=None):
    """Numerically stable softmax with temperature t (higher t -> flatter)."""
    w = np.array(w) / t
    e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
    dist = e / np.sum(e, axis=axis, keepdims=True)
    return dist
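# Worked example: the temperature t flattens the distribution.
#   softmax([1., 2., 3.])      -> approx. [0.090, 0.245, 0.665]
#   softmax([1., 2., 3.], t=8) -> approx. [0.293, 0.332, 0.376]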
def min_cosine(student, teacher, option, weights=None):
    """Cosine embedding loss between student and teacher features.
    The target of -1 penalizes positive cosine similarity. option is unused;
    the signature mirrors distance_metric.
    """
    # reduction='none' keeps per-sample losses so they can be weighted
    cosine = torch.nn.CosineEmbeddingLoss(reduction='none')
    target = -torch.ones(student.size(0), device=student.device)
    dists = cosine(student, teacher.detach(), target)
    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()
    return dist
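# Sketch of assumed usage: student/teacher are batch_size x dim feature
# tensors; only the student receives gradients (the teacher is detached).
#   loss = min_cosine(student_feats, teacher_feats, option=None)
#   loss.backward()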
def distance_metric(student, teacher, option, weights=None):
    """Distance metric to calculate the imitation loss.
    Args:
        student: batch_size x n_classes
        teacher: batch_size x n_classes
        option: one of [cosine, l2, l1, kl]
        weights: batch_size or float
    Returns:
        The computed distance metric.
    """
    if option == 'cosine':
        dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
    elif option == 'l2':
        dists = (student - teacher.detach()).pow(2).sum(1)
    elif option == 'l1':
        dists = torch.abs(student - teacher.detach()).sum(1)
    elif option == 'kl':
        assert weights is None
        T = 8  # softmax temperature
        # KL divergence, averaged by F.kl_div's default reduction and
        # rescaled by T * T as in standard knowledge distillation
        dist = F.kl_div(
            F.log_softmax(student / T, dim=1),
            F.softmax(teacher.detach() / T, dim=1)) * (T * T)
        return dist
    else:
        raise NotImplementedError
    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()
    return dist
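# Illustrative usage (tensor shapes assumed from the docstring):
#   s = torch.randn(4, 52, requires_grad=True)
#   t = torch.randn(4, 52)
#   loss = distance_metric(s, t, 'cosine')           # per-sample in [0, 2]
#   loss = distance_metric(s, t, 'l2', weights=0.5)  # weighted squared error
#   loss = distance_metric(s, t, 'kl')               # temperature-scaled KL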
def get_segments(input, timestep):
    """Split entire input into overlapping segments of length timestep.
    Args:
        input: 1 x total_length x n_frames x ...
        timestep: length of each segment.
    Returns:
        input: batched segments, num_segments x timestep x ...
        start_indices: start index of each segment.
    """
    assert input.size(0) == 1, 'Test time, batch_size must be 1'
    input.squeeze_(dim=0)
    # Find overlapping segments (segments overlap by half a timestep)
    length = input.size(0)
    step = timestep // 2
    num_segments = (length - timestep) // step + 1
    start_indices = (np.arange(num_segments) * step).tolist()
    # Add a final segment to cover the tail of the input
    if length % step > 0:
        start_indices.append(length - timestep)
    # Stack the segments along a new batch dimension
    segments = []
    for s in start_indices:
        segment = input[s: (s + timestep)].unsqueeze(0)
        segments.append(segment)
    input = torch.cat(segments, dim=0)
    return input, start_indices
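# Worked example: timestep=16 gives step=8 (50% overlap). For an input of
# total_length 42: num_segments = (42 - 16) // 8 + 1 = 4, so start_indices
# starts as [0, 8, 16, 24]; since 42 % 8 > 0, a final segment at
# 42 - 16 = 26 is appended to cover the tail, yielding a 5 x 16 x ... batch.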
def get_stats(logit, label):
    """Calculate the classification accuracy."""
    logit = to_numpy(logit)
    label = to_numpy(label)
    pred = np.argmax(logit, 1)
    acc = np.sum(pred == label) / label.shape[0]
    return acc, pred, label
def get_stats_detection(logit, label, n_classes=52):
    """Calculate the accuracy and mean average precision.
    Class 0 is treated as the background class.
    """
    logit = to_numpy(logit)
    label = to_numpy(label)
    scores = softmax(logit, axis=1)
    pred = np.argmax(logit, 1)
    length = label.shape[0]
    acc = np.sum(pred == label) / length
    # Accuracy on background frames (label 0) and their fraction of the data
    keep_bg = label == 0
    acc_bg = np.sum(pred[keep_bg] == label[keep_bg]) / label[keep_bg].shape[0]
    ratio_bg = np.sum(keep_bg) / length
    # Accuracy on action (non-background) frames
    keep_action = label != 0
    acc_action = np.sum(
        pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]
    # Average precision over classes; classes with no positives yield NaN
    # APs and are filtered out before averaging
    y_true = np.zeros((len(label), n_classes))
    y_true[np.arange(len(label)), label] = 1
    aps = average_precision_score(y_true, scores, average=None)
    aps = list(filter(lambda x: not np.isnan(x), aps))
    ap = np.mean(aps)
    return ap, acc, acc_bg, acc_action, ratio_bg, pred, label
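# Assumed usage for per-frame action detection, where label 0 is background:
#   ap, acc, acc_bg, acc_action, ratio_bg, pred, label = \
#       get_stats_detection(logits, labels, n_classes=52)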
def info(text):
    """Print text in blue (via ANSI escape codes)."""
    print('\033[94m' + text + '\033[0m')
def warn(text):
    """Print text in yellow."""
    print('\033[93m' + text + '\033[0m')
def err(text):
    """Print text in red."""
    print('\033[91m' + text + '\033[0m')
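if __name__ == '__main__':
    # Minimal sanity checks; illustrative only, not part of any pipeline.
    assert is_due(10, [5, 10])
    assert is_due(100, 100, 3)
    x = torch.randn(1, 40, 8)
    segments, starts = get_segments(x, timestep=16)
    info('segments: %s, starts: %s' % (tuple(segments.size()), starts))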