File size: 3,317 Bytes
f9561b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys
from os import path as osp
import argparse
import warnings
import torch
import numpy as np
from PIL import Image
from detectron2.config import instantiate, LazyConfig

# Make the parent of this script's directory importable so that ``utils``
# resolves. ``abspath`` guards against a relative ``__file__`` (e.g. a bare
# "script.py"), for which the double ``dirname`` would collapse to "" and the
# intended directory would never be added; it also keeps the entry valid if
# the working directory changes later.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
# Presumably supplies the helpers used by do_test (dark_inference, argmax_pts,
# auc, L2_dist, ap) — TODO(review): replace with explicit imports.
from utils import *


# Silence FutureWarning messages for the remainder of the run.
warnings.simplefilter(action="ignore", category=FutureWarning)


def do_test(cfg, model, use_dark_inference=False):
    """Evaluate ``model`` on the validation loader and print AUC, dist and AP.

    AUC and L2 distance are averaged only over samples whose
    ``data["gaze_inouts"]`` entry is truthy; AP is computed with ``ap`` from
    the in/out predictions of every sample.

    Args:
        cfg: LazyConfig; ``cfg.dataloader.val`` is instantiated into the
            validation data loader.
        model: Called as ``model(data)`` on each batch and expected to return
            a ``(gaze_heatmap, gaze_inout)`` pair of tensors.
        use_dark_inference: If True, locate the heatmap peak with
            ``dark_inference``; otherwise use ``argmax_pts``.
    """
    val_loader = instantiate(cfg.dataloader.val)

    model.train(False)
    # Resolve the peak-localisation helper once instead of branching per sample.
    locate_peak = dark_inference if use_dark_inference else argmax_pts
    auc_scores = []
    dists = []
    inout_gt = []
    inout_pred = []
    with torch.no_grad():
        for data in val_loader:
            heatmaps_pred, inouts_pred = model(data)
            # Squeeze dim 1 (assumed to be a singleton channel dim — TODO
            # confirm) and move predictions to NumPy.
            heatmaps_pred = heatmaps_pred.squeeze(1).cpu().detach().numpy()
            inouts_pred = inouts_pred.cpu().detach().numpy()

            for b_i, heatmap in enumerate(heatmaps_pred):
                # Point metrics (AUC, dist) are only recorded for samples whose
                # in/out flag is truthy.
                if not data["gaze_inouts"][b_i]:
                    continue
                gt_gaze = data["gazes"][b_i]
                # Binarise the ground-truth heatmap for the AUC computation.
                multi_hot = (data["heatmaps"][b_i] > 0).float().numpy()
                pred_x, pred_y = locate_peak(heatmap)
                # Normalise the predicted peak by heatmap width (x) and height (y).
                norm_p = [
                    pred_x / heatmap.shape[-1],
                    pred_y / heatmap.shape[-2],
                ]
                # AUC is scored on a 64x64 bilinear resize of the prediction.
                scaled_heatmap = np.array(
                    Image.fromarray(heatmap).resize(
                        (64, 64),
                        resample=Image.Resampling.BILINEAR,
                    )
                )
                auc_scores.append(auc(scaled_heatmap, multi_hot))
                dists.append(L2_dist(gt_gaze.numpy(), norm_p))
            inout_gt.extend(data["gaze_inouts"].cpu().numpy())
            inout_pred.extend(inouts_pred)

    print("|AUC   |dist    |AP     |")
    print(
        "|{:.4f}|{:.4f}  |{:.4f}  |".format(
            torch.mean(torch.tensor(auc_scores)),
            torch.mean(torch.tensor(dists)),
            ap(inout_gt, inout_pred),
        )
    )


def main(args):
    """Build the model from ``args.config_file``, load weights and evaluate it.

    Args:
        args: Parsed CLI namespace providing ``config_file``,
            ``model_weights`` and ``use_dark_inference``.
    """
    cfg = LazyConfig.load(args.config_file)
    # torch.nn.Module is the real class; ``torch.Module`` does not exist.
    model: torch.nn.Module = instantiate(cfg.model)
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only machines (and no transient GPU copy is made); load_state_dict
    # copies the tensors into the model, which is then moved to the
    # configured device.
    checkpoint = torch.load(args.model_weights, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(cfg.train.device)
    do_test(cfg, model, use_dark_inference=args.use_dark_inference)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them LazyConfig.load / torch.load
    # would receive None and fail with an unhelpful error deep in main().
    parser.add_argument(
        "--config_file",
        type=str,
        required=True,
        help="config file",
    )
    parser.add_argument(
        "--model_weights",
        type=str,
        required=True,
        help="model weights",
    )
    parser.add_argument("--use_dark_inference", action="store_true")
    args = parser.parse_args()
    main(args)