|
import numpy as np |
|
from sklearn.metrics import average_precision_score |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.data |
|
|
|
def to_numpy(array):
    """Convert *array* to a ``np.ndarray``.

    numpy arrays pass through untouched; torch Variables/tensors are
    unwrapped (``.data``), moved to host memory when on GPU, and converted.
    """
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, torch.autograd.Variable):
        array = array.data
    host = array.cpu() if array.is_cuda else array
    return host.numpy()
|
|
|
|
|
def squeeze(array):
    """Unwrap a single-element list; return anything else unchanged.

    Fix: the original returned ``array[0]`` for every list of length <= 1,
    so an empty list raised IndexError. An empty list is now returned
    as-is; only exactly-one-element lists are unwrapped.
    """
    if isinstance(array, list) and len(array) == 1:
        return array[0]
    return array
|
|
|
|
|
def unsqueeze(array):
    """Wrap *array* in a one-element list unless it is already a list."""
    return array if isinstance(array, list) else [array]
|
|
|
|
|
def is_due(*args):
    """Determines whether to perform an action or not, depending on the epoch.

    Used for logging, saving, learning rate decay, etc.

    Args:
        *args: epoch, due_at (due at epoch due_at)
            epoch, num_epochs, due_every (due every due_every epochs,
            and always at the final epoch)
            step, due_every (due every due_every steps)

    Returns:
        due: boolean: perform action or not
    """
    if len(args) == 2 and isinstance(args[1], list):
        epoch, due_at = args
        due = epoch in due_at
    elif len(args) == 3:
        epoch, num_epochs, due_every = args
        # Fix: require due_every > 0 (was >= 0). A due_every of 0 used to
        # reach `epoch % 0` and raise ZeroDivisionError; now it means
        # "never", consistent with the step branch below.
        due = (due_every > 0) and (epoch % due_every == 0 or epoch == num_epochs)
    else:
        step, due_every = args
        due = (due_every > 0) and (step % due_every == 0)

    return due
|
|
|
|
|
def softmax(w, t=1.0, axis=None):
    """Temperature-scaled softmax of *w* along *axis*.

    Subtracts the running max before exponentiating for numerical
    stability; ``t`` is the softmax temperature.
    """
    scaled = np.array(w) / t
    shifted = scaled - np.amax(scaled, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)
|
|
|
|
|
def min_cosine(student, teacher, option, weights=None):
    """Cosine-embedding loss pulling `student` away from/toward `teacher`.

    Args:
        student: batch_size x n_classes
        teacher: batch_size x n_classes (gradients blocked via detach)
        option: unused; kept for signature parity with distance_metric
        weights: optional per-sample weights (batch_size) or float

    Returns:
        Scalar loss tensor.

    Fixes: the target tensor now follows `student.device` (the hard-coded
    `.cuda()` crashed on CPU-only machines), and `reduction='none'` keeps
    per-sample losses so `weights` actually weight individual samples
    (the default 'mean' reduction collapsed the batch before weighting).
    For weights=None the returned mean is unchanged.
    """
    cosine = torch.nn.CosineEmbeddingLoss(reduction='none')
    # target = -1 for every sample: penalize similarity, i.e. max(0, cos).
    target = -torch.ones(student.size(0), device=student.device)
    dists = cosine(student, teacher.detach(), target)
    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist
|
|
|
|
|
|
|
def distance_metric(student, teacher, option, weights=None):
    """Distance metric to calculate the imitation loss.

    Args:
        student: batch_size x n_classes
        teacher: batch_size x n_classes (gradients blocked via detach)
        option: one of [cosine, l2, l1, kl]  (docstring fixed: was
            "l2, l2" for the l1 entry)
        weights: batch_size or float; not supported for 'kl'

    Returns:
        The computed distance metric (scalar tensor).

    Raises:
        NotImplementedError: if `option` is not one of the four above.
    """
    if option == 'cosine':
        dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)

    elif option == 'l2':
        dists = (student - teacher.detach()).pow(2).sum(1)
    elif option == 'l1':
        dists = torch.abs(student - teacher.detach()).sum(1)
    elif option == 'kl':
        assert weights is None
        T = 8  # distillation temperature

        # dim=1 made explicit: the implicit softmax dim is deprecated; the
        # inputs are batch_size x n_classes, so dim=1 matches old behavior.
        # NOTE(review): kl_div keeps the default element-wise 'mean'
        # reduction; 'batchmean' matches the KL definition but would
        # rescale the loss, so it is deliberately left unchanged.
        dist = F.kl_div(
            F.log_softmax(student / T, dim=1),
            F.softmax(teacher.detach() / T, dim=1)) * (
                T * T)
        return dist
    else:
        raise NotImplementedError

    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist
|
|
|
|
|
def get_segments(input, timestep):
    """Split entire input into segments of length timestep.

    Consecutive segments overlap by ``timestep // 2``. If the sliding
    windows do not reach the end of the sequence, one extra segment ending
    exactly at the last frame is appended.

    WARNING: mutates the caller's tensor in place (``squeeze_``), removing
    its leading batch dimension.

    Args:
        input: 1 x total_length x n_frames x ...
        timestep: segment length along the time axis; assumes
            total_length >= timestep — TODO confirm callers guarantee this.

    Returns:
        input: concatenated video segments, num_segments x timestep x ...
        start_indices: start index of each segment along the time axis
    """
    assert input.size(0) == 1, 'Test time, batch_size must be 1'

    # In-place: drops the batch dimension from the caller's tensor too.
    input.squeeze_(dim=0)

    length = input.size()[0]
    step = timestep // 2  # stride: 50% overlap between segments
    # Number of full windows that fit when sliding by `step`.
    num_segments = (length - timestep) // step + 1
    start_indices = (np.arange(num_segments) * step).tolist()
    if length % step > 0:
        # Tail frames not covered by the regular windows: add one final
        # segment aligned to the end of the sequence (may overlap more).
        start_indices.append(length - timestep)


    # Re-add a leading dim per segment so they stack along dim 0.
    segments = []
    for s in start_indices:
        segment = input[s: (s + timestep)].unsqueeze(0)
        segments.append(segment)
    input = torch.cat(segments, dim=0)
    return input, start_indices
|
|
|
def get_stats(logit, label):
    '''
    Calculate the top-1 accuracy.

    Returns (acc, pred, label), everything converted to numpy; pred is
    the argmax over the class axis.
    '''
    scores = to_numpy(logit)
    targets = to_numpy(label)

    pred = np.argmax(scores, 1)
    acc = np.sum(pred == targets) / targets.shape[0]

    return acc, pred, targets
|
|
|
|
|
def get_stats_detection(logit, label, n_classes=52):
    '''
    Calculate the accuracy and average precisions.

    Args:
        logit: batch_size x n_classes raw scores.
        label: batch_size integer class ids; 0 is treated as background.
        n_classes: total number of classes (width of the one-hot target).

    Returns:
        (ap, acc, acc_bg, acc_action, ratio_bg, pred, label):
        mean average precision over classes present in the batch, overall
        accuracy, background accuracy, action (non-background) accuracy,
        fraction of background frames, predictions, and labels.

    Fix: the original recomputed `acc` a second time with the identical
    expression just before the AP computation; the duplicate is removed.
    '''
    logit = to_numpy(logit)
    label = to_numpy(label)
    scores = softmax(logit, axis=1)

    pred = np.argmax(logit, 1)
    length = label.shape[0]
    acc = np.sum(pred == label)/length

    # Accuracy on background frames (class 0) and their share of the batch.
    keep_bg = label == 0
    acc_bg = np.sum(pred[keep_bg] == label[keep_bg])/label[keep_bg].shape[0]
    ratio_bg = np.sum(keep_bg)/length

    # Accuracy restricted to action (non-background) frames.
    keep_action = label != 0
    acc_action = np.sum(
        pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]

    # One-hot targets for per-class AP; classes absent from the batch give
    # NaN APs, which are dropped before averaging.
    y_true = np.zeros((len(label), n_classes))
    y_true[np.arange(len(label)), label] = 1
    aps = average_precision_score(y_true, scores, average=None)
    aps = list(filter(lambda x: not np.isnan(x), aps))
    ap = np.mean(aps)

    return ap, acc, acc_bg, acc_action, ratio_bg, pred, label
|
|
|
|
|
def info(text):
    """Print *text* wrapped in the ANSI bright-blue escape sequence."""
    print('\033[94m{}\033[0m'.format(text))
|
|
|
|
|
def warn(text):
    """Print *text* wrapped in the ANSI bright-yellow escape sequence."""
    print('\033[93m{}\033[0m'.format(text))
|
|
|
|
|
def err(text):
    """Print *text* wrapped in the ANSI bright-red escape sequence."""
    print('\033[91m{}\033[0m'.format(text))
|
|