PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
010952f verified
raw
history blame
868 Bytes
from valids import parser, main as valids_main
import os.path as osp
args = parser.parse_args()
args.target = "valid_accuracy"
args.best_biggest = True
args.best = True
args.last = 0
args.path_contains = None
res = valids_main(args, print_output=False)
grouped = {}
for k, v in res.items():
k = osp.dirname(k)
run = osp.dirname(k)
task = osp.basename(k)
val = v["valid_accuracy"]
if run not in grouped:
grouped[run] = {}
grouped[run][task] = val
for run, tasks in grouped.items():
print(run)
avg = sum(float(v) for v in tasks.values()) / len(tasks)
avg_norte = sum(float(v) for k,v in tasks.items() if k != 'rte') / (len(tasks) -1)
try:
print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
except:
print(tasks)
print()