PyTorch
ssl-aasist
custom_code
File size: 868 Bytes
010952f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from valids import parser, main as valids_main
import os.path as osp


# Parse the command line, then force the options this report depends on,
# overriding whatever was supplied for them by the caller.
args = parser.parse_args()
overrides = {
    "target": "valid_accuracy",
    "best_biggest": True,
    "best": True,
    "last": 0,
    "path_contains": None,
}
for option, value in overrides.items():
    setattr(args, option, value)

# NOTE(review): the exact meaning of each option is defined by the `valids`
# module; presumably this selects the best (highest) valid_accuracy entry per
# checkpoint directory -- confirm there. Printing is turned off because the
# result is post-processed further down in this script.
res = valids_main(args, print_output=False)

# Regroup the flat results into {run: {task: accuracy}}.
# Each key of `res` is treated as a path: after dropping its final component,
# the last remaining directory names the task and that directory's parent
# names the run.
grouped = {}
for path, metrics in res.items():
    task_dir = osp.dirname(path)
    run_name, task_name = osp.dirname(task_dir), osp.basename(task_dir)
    accuracy = metrics["valid_accuracy"]
    # A later entry for the same run/task pair overwrites an earlier one.
    grouped.setdefault(run_name, {})[task_name] = accuracy

# Print one report entry per run: the run name, then a tab-separated row with
# the cola, qnli, mrpc, rte and sst_2 accuracies followed by the average over
# all of the run's tasks and the average over all tasks except rte.
# A run that lacks any of the five named tasks gets its raw task dict printed
# instead of the row.
for run, tasks in grouped.items():
    print(run)
    scores = {name: float(score) for name, score in tasks.items()}
    avg = sum(scores.values()) / len(scores)
    # Divide by the number of scores actually summed. The previous
    # `len(tasks) - 1` denominator raised ZeroDivisionError for a run with a
    # single task (aborting the whole report) and was off by one whenever
    # 'rte' was absent.
    non_rte = [score for name, score in scores.items() if name != 'rte']
    avg_norte = sum(non_rte) / len(non_rte) if non_rte else float("nan")
    try:
        print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
    except KeyError:
        # Only a missing task column is expected here; anything else
        # (including KeyboardInterrupt) should propagate, not be swallowed.
        print(tasks)
    print()