import argparse
import json
import os
import pickle

import numpy as np

from genception.utils import find_files


def read_all_pkl(folder_path: str) -> dict:
    """
    Read all the pickle files in the given folder path

    Args:
    folder_path: str: The path to the folder

    Returns:
    dict: The dictionary containing the file path as key and the pickle file content as value
    """
    result_dict = dict()
    file_list = find_files(folder_path, {".pkl"})
    for file_path in file_list:
        with open(file_path, "rb") as file:
            # NOTE: pickle.load can execute arbitrary code while unpickling;
            # only point this at result folders produced by a trusted run.
            result_dict[file_path] = pickle.load(file)
    return result_dict


def integrated_decay_area(scores: list[float]) -> float:
    """
    Calculate the Integrated Decay Area (IDA) for the given scores

    The i-th score (1-based) is weighted by i, and the weighted sum is
    normalised by the sum of the weights 1 + 2 + ... + len(scores).

    Args:
    scores: list[float]: The list of scores

    Returns:
    float: The IDA score, or 0.0 when no scores are given
    """
    total_area = 0
    for position, score in enumerate(scores, start=1):
        total_area += position * score
    max_possible_area = sum(range(1, len(scores) + 1))
    if not max_possible_area:
        return 0.0
    # Convert to a built-in float so NumPy scalar inputs (e.g. np.float32)
    # do not leak into the result, which main() serialises with json.dump.
    return float(total_area / max_possible_area)


def gc_score(folder_path: str, n_iter: int = None) -> tuple[float, list[float]]:
    """
    Calculate the GC@T score for the given folder path

    Args:
    folder_path: str: The path to the folder
    n_iter: int: The number of iterations to consider for GC@T score;
        None uses every available score of every file

    Returns:
    tuple[float, list[float]]: The GC@T score (mean over the files that were
        scored, NaN when no file was scored) and the list of GC scores for
        each scored file
    """
    test_data = read_all_pkl(folder_path)
    all_gc_scores = []
    for value in test_data.values():
        similarities = value["cosine_similarities"]
        # The first similarity entry is excluded from the score.
        sim_score = similarities[1:]
        if n_iter is None:
            _gc = integrated_decay_area(sim_score)
        elif len(similarities) >= n_iter:
            # NOTE(review): the length check uses the un-sliced list, so a
            # file with exactly n_iter entries contributes only n_iter - 1
            # scores here. Confirm whether `> n_iter` was intended.
            _gc = integrated_decay_area(sim_score[:n_iter])
        else:
            # Files with fewer than n_iter entries are skipped.
            continue
        all_gc_scores.append(_gc)

    if not all_gc_scores:
        # np.mean([]) also yields NaN but emits a RuntimeWarning; return the
        # same value explicitly so the empty case is deliberate.
        return float("nan"), all_gc_scores
    return float(np.mean(all_gc_scores)), all_gc_scores


def main():
    """
    Parse the command line arguments, compute the GC@T score and write it,
    together with the per-file scores, to GC@<t>.json inside the results folder
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--results_path",
        type=str,
        help="Path to the folder containing the pickle files",
        required=True,
    )
    parser.add_argument(
        "--t",
        type=int,
        help="Number of iterations to consider for GC@T score",
        required=True,
    )
    args = parser.parse_args()

    # calculate GC@T score and save in results directory
    gc, all_gc_scores = gc_score(args.results_path, args.t)
    result = {
        "GC Score": gc,
        "All GC Scores": all_gc_scores,
    }
    results_path = os.path.join(args.results_path, f"GC@{args.t}.json")
    with open(results_path, "w", encoding="utf-8") as file:
        json.dump(result, file)


if __name__ == "__main__":
    main()