import os.path as osp
import re
from collections import defaultdict
from valids import parser, main as valids_main
# Official GLUE metric reported for each finetuning task.
TASK_TO_METRIC = {
    "cola": "mcc",
    "qnli": "accuracy",
    "mrpc": "acc_and_f1",
    "rte": "accuracy",
    "sst_2": "accuracy",
    "mnli": "accuracy",
    "qqp": "acc_and_f1",
    "sts_b": "pearson_and_spearman",
}
# Canonical task order used by every printed report (dict preserves
# insertion order, so this matches the literal list it replaces).
TASKS = list(TASK_TO_METRIC)
def get_best_stat_str(task_vals, show_subdir):
    """Return one tab-separated summary line of the best value per task.

    Args:
        task_vals: mapping of task name -> {subdir: metric value}.
        show_subdir: if True, emit "(best_subdir, value)" pairs instead of
            bare values.

    Returns:
        Values in TASKS order (missing tasks rendered as -100 / 'null'),
        followed by two averages formatted to 2 decimals: over all present
        tasks, and over all present tasks except "rte".
        On any failure (e.g. empty task_vals), returns the exception text
        plus a dump of the raw stats instead of raising.
    """
    task_to_best_val = {}
    task_to_best_dir = {}
    for task, subdir_to_val in task_vals.items():
        # Best subdir is the one holding the maximum metric value; on ties
        # max() keeps the first maximal key, matching max() over values.
        best_subdir = max(subdir_to_val, key=subdir_to_val.get)
        task_to_best_dir[task] = best_subdir
        task_to_best_val[task] = subdir_to_val[best_subdir]
    try:
        # Averages are computed inside the try so that an empty task_vals
        # (ZeroDivisionError) degrades to the error-string fallback below
        # instead of crashing the whole report.
        n_all = len(task_to_best_val)
        n_no_rte = len([k for k in task_to_best_val if k != "rte"])
        avg_all = sum(task_to_best_val.values()) / n_all
        avg_no_rte = sum(v for t, v in task_to_best_val.items() if t != "rte") / n_no_rte
        msg = ""
        for task in TASKS:
            best_dir = task_to_best_dir.get(task, 'null')
            val = task_to_best_val.get(task, -100)
            msg += f"({best_dir}, {val})\t" if show_subdir else f"{val}\t"
        msg += f"{avg_all:.2f}\t{avg_no_rte:.2f}"
    except Exception as e:
        # Best-effort fallback: report the error and dump the raw stats so
        # the surrounding run loop keeps printing the remaining runs.
        msg = str(e)
        msg += str(sorted(task_vals.items()))
    return msg
def get_all_stat_str(task_vals):
    """Return every (subdir, value) pair per task, tasks in TASKS order.

    Each present task contributes a "=== <task>" heading followed by one
    indented "subdir<TAB>value" line per subdir, sorted by subdir name.
    """
    parts = []
    for task in TASKS:
        if task not in task_vals:
            continue
        parts.append(f"=== {task}\n")
        subdir_to_val = task_vals[task]
        for subdir in sorted(subdir_to_val):
            parts.append(f"\t{subdir}\t{subdir_to_val[subdir]}\n")
    return "".join(parts)
def get_tabular_stat_str(task_vals):
    """Return a runs-by-params table of values for each present task.

    Assumes every subdir has the form "<param>/run_*/0" where <param>
    parses as a float (it is used as the numeric sort key).

    Args:
        task_vals: mapping of task name -> {subdir: metric value}.

    Returns:
        For each task (in TASKS order): a heading, a tab-joined list of run
        names, a tab-joined list of params, then one row per param with the
        value for each run (missing cells rendered as "None").
    """
    # Hoisted out of the loop: one compile instead of one per subdir.
    subdir_re = re.compile(r"(.*)/(run_.*)/0")
    msg = ""
    for task in [task for task in TASKS if task in task_vals]:
        msg += f"=== {task}\n"
        param_to_runs = defaultdict(dict)
        for subdir in task_vals[task]:
            match = subdir_re.match(subdir)
            # Include the offending subdir so the failure is diagnosable.
            assert match, f"unexpected subdir format: {subdir}"
            param, run = match.groups()
            param_to_runs[param][run] = task_vals[task][subdir]
        params = sorted(param_to_runs, key=float)
        runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
        msg += ("runs:" + "\t".join(runs) + "\n")
        msg += ("params:" + "\t".join(params) + "\n")
        for param in params:
            msg += "\t".join(str(param_to_runs[param].get(run, None)) for run in runs)
            msg += "\n"
    return msg
def main():
    """Aggregate finetuning results from valids and print per-run summaries.

    Extends the shared `valids` parser with display options, groups the
    results by (run, task, subdir), and prints either the best value per
    task, all values, or a runs-by-params table — using plain accuracy or
    each task's official GLUE metric.
    """
    parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
    parser.add_argument("--print_mode", default="best", help="best|all|tabular")
    parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
    parser.add_argument("--override_target", default="valid_accuracy", help="override target")
    args = parser.parse_args()
    # Force the options valids_main expects, regardless of parser defaults.
    args.target = args.override_target
    args.best_biggest = True
    args.best = True
    args.last = 0
    args.path_contains = None
    res = valids_main(args, print_output=False)
    grouped_acc = {}  # run -> task -> subdir -> valid accuracy
    grouped_met = {}  # run -> task -> subdir -> official GLUE metric
    for path, v in res.items():
        # NOTE(review): args.base is not set in this file — presumably the
        # shared `valids` parser defines it; confirm against valids.py.
        path = "/".join([args.base, path])
        path = re.sub("//*", "/", path)  # collapse duplicate slashes
        match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
        if not match:
            continue
        run, task, subdir = match.groups()
        grouped_acc.setdefault(run, {}).setdefault(task, {})
        grouped_met.setdefault(run, {}).setdefault(task, {})
        if v is None:
            print(f"{path} has None return")
            continue
        grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
        grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
    header = "\t".join(TASKS)
    for run in sorted(grouped_acc):
        print(run)
        # The GLUE/ACC choice only selects which grouped dict and label to
        # use; the print mode is otherwise identical for both.
        stats = grouped_met[run] if args.show_glue else grouped_acc[run]
        label = "GLUE" if args.show_glue else "ACC"  # fixed "GLEU" typo
        if args.print_mode == "all":
            print(f"===== {label} =====")
            print(get_all_stat_str(stats))
        elif args.print_mode == "best":
            print(" " * 6 + header)  # pad to align with the "GLUE: " prefix
            print(f"{label}: {get_best_stat_str(stats, args.show_subdir)}")
        elif args.print_mode == "tabular":
            print(f"===== {label} =====")
            print(get_tabular_stat_str(stats))
        else:
            raise ValueError(args.print_mode)
        print()


if __name__ == "__main__":
    main()