# GenCeption / genception/evaluation.py
# Author: cao-lele
# Commit: initial commit (0724c4e)
# Size: 2.77 kB
# (File-viewer links "raw" and "history blame" were captured with this header.)
import os
import json
import pickle
import numpy as np
import argparse
from genception.utils import find_files
def read_all_pkl(folder_path: str) -> dict:
    """
    Load every pickle file that find_files reports for the given folder

    Each path returned by find_files(folder_path, {".pkl"}) is opened in
    binary mode and unpickled; the loaded object is stored under that path.

    NOTE: pickle.load can execute arbitrary code while deserializing, so this
    must only be pointed at result folders that come from a trusted source.

    Args:
    folder_path: str: The path to the folder to search
    Returns:
    dict: Mapping of file path -> unpickled content of that file
    """
    loaded = {}
    for pkl_path in find_files(folder_path, {".pkl"}):
        with open(pkl_path, "rb") as handle:
            loaded[pkl_path] = pickle.load(handle)
    return loaded
def integrated_decay_area(scores: list[float]) -> float:
    """
    Compute the Integrated Decay Area (IDA) of a score sequence

    The IDA is a weighted average in which the score at 1-based position k
    carries weight k, i.e. sum(k * score_k) / sum(k). Later positions thus
    count more than earlier ones.

    Args:
    scores: list[float]: The scores, in order
    Returns:
    float: The IDA value; 0 when scores is empty
    """
    count = len(scores)
    # 1 + 2 + ... + count, the total weight (closed form of the arithmetic series)
    weight_total = count * (count + 1) // 2
    if not weight_total:
        # Empty input: nothing to average, and dividing would fail
        return 0

    weighted_sum = 0
    for position, value in enumerate(scores, start=1):
        weighted_sum += position * value
    return weighted_sum / weight_total
def gc_score(folder_path: str, n_iter: int = None) -> tuple[float, list[float]]:
    """
    Calculate the GC@T score over all pickle files in the given folder

    For each loaded file, the first entry of its "cosine_similarities" list is
    dropped and the Integrated Decay Area of the remaining entries (truncated
    to at most n_iter entries when n_iter is given) is that file's GC score.

    Args:
    folder_path: str: The path to the folder containing the pickle files
    n_iter: int: The number of iterations to consider for GC@T score; when
        None, all available similarities (after the first) are used and no
        file is skipped
    Returns:
    tuple[float, list[float]]: The mean GC score over the contributing files
        and the list of per-file GC scores. If no file contributes, the list
        is empty and the mean is NaN (numpy's result for an empty mean).
    """
    per_file_scores = []
    for record in read_all_pkl(folder_path).values():
        similarities = record["cosine_similarities"]
        # The first similarity is excluded from scoring
        usable = similarities[1:]
        if n_iter is not None:
            # NOTE(review): the length test is made on the full list (first
            # entry included), so a file with exactly n_iter entries passes but
            # contributes only n_iter - 1 scores. Confirm this is intended.
            if len(similarities) < n_iter:
                # Too few iterations recorded: leave this file out entirely
                continue
            usable = usable[:n_iter]
        per_file_scores.append(integrated_decay_area(usable))
    return np.mean(per_file_scores), per_file_scores
def main():
    """
    Command-line entry point: compute GC@T for a results folder and save it

    Parses the required --results_path and --t options, computes the GC@T
    score with gc_score, and writes a JSON file named "GC@<t>.json" into the
    results folder. The JSON holds the mean score under "GC Score" and the
    per-file scores under "All GC Scores".
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--results_path",
        type=str,
        required=True,
        help="Path to the folder containing the pickle files",
    )
    cli.add_argument(
        "--t",
        type=int,
        required=True,
        help="Number of iterations to consider for GC@T score",
    )
    options = cli.parse_args()

    # Compute the scores first; the output file lands next to the pickle files
    mean_score, per_file_scores = gc_score(options.results_path, options.t)
    payload = {
        "GC Score": mean_score,
        "All GC Scores": per_file_scores,
    }
    output_file = os.path.join(options.results_path, f"GC@{str(options.t)}.json")
    with open(output_file, "w") as sink:
        json.dump(payload, sink)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()