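"""Summarize GLUE fine-tuning results collected by the companion `valids` script.

Results are grouped per run directory, per task, and per sub-directory, then printed
either as the best value for each task (--print_mode best), as a full listing (all),
or as a hyperparameter-by-run table (tabular). With --show_glue the official GLUE
metric of each task is reported instead of plain validation accuracy.
"""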
import re
from collections import defaultdict

from valids import parser, main as valids_main


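# Official GLUE metric for each task, looked up below under the key "valid_<metric>".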
TASK_TO_METRIC = {
    "cola": "mcc",
    "qnli": "accuracy",
    "mrpc": "acc_and_f1",
    "rte": "accuracy",
    "sst_2": "accuracy",
    "mnli": "accuracy",
    "qqp": "acc_and_f1",
    "sts_b": "pearson_and_spearman",
}
TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"]


def get_best_stat_str(task_vals, show_subdir):
    """Return one tab-separated line with the best value per task (optionally with its subdir) and two averages."""
    task_to_best_val = {}
    task_to_best_dir = {}
    for task, subdir_to_val in task_vals.items():
        task_to_best_val[task] = max(subdir_to_val.values())
        task_to_best_dir[task] = max(subdir_to_val, key=subdir_to_val.get)

    N1 = len(task_to_best_val)
    N2 = len([k for k in task_to_best_val if k != "rte"])
    avg1 = sum(task_to_best_val.values()) / N1  # average over all tasks
    avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2  # average excluding rte

    try:
        msg = ""
        for task in TASKS:
            best_dir = task_to_best_dir.get(task, "null")
            val = task_to_best_val.get(task, -100)
            msg += f"({best_dir}, {val})\t" if show_subdir else f"{val}\t"
        msg += f"{avg1:.2f}\t{avg2:.2f}"
    except Exception as e:
        msg = str(e)
        msg += str(sorted(task_vals.items()))
    return msg

def get_all_stat_str(task_vals):
    """List every (subdir, value) pair, grouped under a header line per task."""
    msg = ""
    for task in [task for task in TASKS if task in task_vals]:
        msg += f"=== {task}\n"
        for subdir in sorted(task_vals[task].keys()):
            msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
    return msg

def get_tabular_stat_str(task_vals):
    """Build a <param>-by-<run> value table per task; assumes each subdir has the form <param>/run_*/0."""
    msg = ""
    for task in [task for task in TASKS if task in task_vals]:
        msg += f"=== {task}\n"
        param_to_runs = defaultdict(dict)
        for subdir in task_vals[task]:
            match = re.match("(.*)/(run_.*)/0", subdir)
            assert match, f"unexpected subdir format: {subdir}"
            param, run = match.groups()
            param_to_runs[param][run] = task_vals[task][subdir]
        params = sorted(param_to_runs, key=float)
        runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
        msg += ("runs:" + "\t".join(runs) + "\n")
        msg += ("params:" + "\t".join(params) + "\n")
        for param in params:
            msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs])
            msg += "\n"
    return msg


def main():
    parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
    parser.add_argument("--print_mode", default="best", help="best|all|tabular")
    parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
    parser.add_argument("--override_target", default="valid_accuracy", help="override target")

    args = parser.parse_args()
    args.target = args.override_target
    args.best_biggest = True
    args.best = True
    args.last = 0
    args.path_contains = None
    
    res = valids_main(args, print_output=False)
    grouped_acc = {}  # run -> task -> subdir -> valid_accuracy
    grouped_met = {}  # run -> task -> subdir -> official GLUE metric for the task
    for path, v in res.items():
        path = "/".join([args.base, path])
        path = re.sub("//*", "/", path)
        match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
        if not match:
            continue
        run, task, subdir = match.groups()

        if run not in grouped_acc:
            grouped_acc[run] = {}
            grouped_met[run] = {}
        if task not in grouped_acc[run]:
            grouped_acc[run][task] = {}
            grouped_met[run][task] = {}

        if v is not None:
            grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
            grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
        else:
            print(f"{path} has None return")

    header = "\t".join(TASKS)
    for run in sorted(grouped_acc):
        print(run)
        if args.print_mode == "all":
            if args.show_glue:
                print("===== GLUE =====")
                print(get_all_stat_str(grouped_met[run]))
            else:
                print("===== ACC =====")
                print(get_all_stat_str(grouped_acc[run]))
        elif args.print_mode == "best":
            print(f"      {header}")
            if args.show_glue:
                print(f"GLEU: {get_best_stat_str(grouped_met[run], args.show_subdir)}")
            else:
                print(f"ACC:  {get_best_stat_str(grouped_acc[run], args.show_subdir)}")
        elif args.print_mode == "tabular":
            if args.show_glue:
                print("===== GLUE =====")
                print(get_tabular_stat_str(grouped_met[run]))
            else:
                print("===== ACC =====")
                print(get_tabular_stat_str(grouped_acc[run]))
        else:
            raise ValueError(args.print_mode)
        print()

if __name__ == "__main__":
    main()
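
# Example invocation (a sketch): the script name and the checkpoint-directory argument
# below are placeholders; the actual positional arguments (e.g. `base`) and any shared
# flags are defined by the parser imported from `valids.py`, so check that module for
# the exact interface.
#
#   python parse_glue_results.py /path/to/checkpoints --print_mode tabular --show_glue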