File size: 3,377 Bytes
f9561b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import sys
from os import path as osp
import argparse
import warnings
import torch
import numpy as np
from PIL import Image
from detectron2.config import instantiate, LazyConfig
sys.path.append(osp.dirname(osp.dirname(__file__)))
from utils import *
warnings.simplefilter(action="ignore", category=FutureWarning)
def do_test(cfg, model, use_dark_inference=False):
    """Evaluate ``model`` on the validation loader and print summary metrics.

    Each sample is scored three ways: the AUC of the predicted heatmap
    (resized to the image size) against the multi-hot ground truth, the
    minimum L2 distance between the predicted point and any annotated gaze
    point, and the L2 distance between the predicted point and the mean of
    the annotated gaze points.  The per-sample scores are averaged and
    printed as a one-row table.

    Args:
        cfg: Config object; ``cfg.dataloader.val`` is instantiated to obtain
            the validation loader.
        model: Callable that returns ``(heatmap, _)`` for a batch.  It is put
            into evaluation mode before use.
        use_dark_inference: When true, the predicted point is located with
            ``dark_inference`` instead of ``argmax_pts``.
    """
    val_loader = instantiate(cfg.dataloader.val)

    model.eval()
    AUC = []
    min_dist = []
    avg_dist = []
    # Only the selected branch is evaluated, so the unused helper is never
    # looked up.
    locate_peak = dark_inference if use_dark_inference else argmax_pts

    with torch.no_grad():
        for data in val_loader:
            heatmaps, _ = model(data)
            heatmaps = heatmaps.squeeze(1).cpu().detach().numpy()

            # go through each data point and record AUC, min dist, avg dist
            for b_i, heatmap in enumerate(heatmaps):
                gazes = data["gazes"][b_i]
                imsize = data["imsize"][b_i]

                # remove padding (entries equal to -1) and recover the valid
                # ground truth points as rows of two coordinates
                valid_gaze = gazes[gazes != -1].view(-1, 2)

                # AUC: area under curve of ROC
                multi_hot = multi_hot_targets(gazes, imsize)
                pred_x, pred_y = locate_peak(heatmap)

                # Normalize x by the heatmap width and y by its height.  The
                # array is laid out (height, width), so width is the last
                # axis.  NOTE(review): assumes pred_x is the horizontal
                # (column) coordinate -- confirm against the helpers in
                # utils.  The result is unchanged for square heatmaps.
                heatmap_h, heatmap_w = heatmap.shape[-2:]
                norm_p = [pred_x / heatmap_w, pred_y / heatmap_h]

                scaled_heatmap = np.array(
                    Image.fromarray(heatmap).resize(
                        imsize,
                        resample=Image.BILINEAR,
                    )
                )
                AUC.append(auc(scaled_heatmap, multi_hot))

                # min distance: minimum among all possible pairs of
                # <ground truth point, predicted point>
                min_dist.append(
                    min(L2_dist(gt_gaze, norm_p) for gt_gaze in valid_gaze)
                )

                # average distance: distance between the predicted point and
                # human average point
                mean_gt_gaze = torch.mean(valid_gaze, 0)
                avg_dist.append(L2_dist(mean_gt_gaze, norm_p))

    print("|AUC |min dist|avg dist|")
    print(
        "|{:.4f}|{:.4f} |{:.4f} |".format(
            torch.mean(torch.tensor(AUC)),
            torch.mean(torch.tensor(min_dist)),
            torch.mean(torch.tensor(avg_dist)),
        )
    )
def main(args):
    """Build the model from the config, load its weights and run the evaluation.

    Args:
        args: Parsed command-line namespace providing ``config_file``,
            ``model_weights`` and ``use_dark_inference``.
    """
    cfg = LazyConfig.load(args.config_file)
    model: torch.nn.Module = instantiate(cfg.model)

    # Deserialize onto the CPU so a checkpoint written on a GPU can still be
    # read on a host without that device; the model is moved to the configured
    # device afterwards.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(args.model_weights, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(cfg.train.device)

    do_test(cfg, model, use_dark_inference=args.use_dark_inference)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the result to main().
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the loaders would receive None
    # and fail with an unrelated-looking error further down.
    parser.add_argument(
        "--config_file",
        type=str,
        required=True,
        help="config file",
    )
    parser.add_argument(
        "--model_weights",
        type=str,
        required=True,
        help="model weights",
    )
    parser.add_argument("--use_dark_inference", action="store_true")
    args = parser.parse_args()
    main(args)
|