import argparse
import os
import random
from pathlib import Path

import numpy as np
from sklearn.utils import shuffle


if __name__ == "__main__":
    """
    Standalone script to process a km file: split it into
    train / valid / test subsets, optionally shuffling the
    lines before the split.
    """
    parser = argparse.ArgumentParser(description="split a km file into train/valid/test sets")
    parser.add_argument("km", type=str, help="path to km file")
    parser.add_argument("--destdir", required=True, type=str, help="directory to write the split files to")
    parser.add_argument("--valid-percent", type=float, default=0.05, help="fraction of lines to allocate to the validation set")
    parser.add_argument("--test-percent", type=float, default=0.05, help="fraction of lines to allocate to the test set")
    parser.add_argument("-sh", "--shuffle", action="store_true", help="shuffle lines before splitting")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    args = parser.parse_args()

    # seed the RNGs so the shuffle / split is reproducible
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.makedirs(args.destdir, exist_ok=True)
    with open(args.km, "r") as f:
        km = f.readlines()

    if args.shuffle:
        km = shuffle(km, random_state=args.seed)
        print("shuffled")

    # valid and test sets are carved off the end of the (possibly shuffled) list
    N = len(km)
    N_tt = int(N * args.test_percent)
    N_cv = int(N * args.valid_percent)
    N_tr = N - N_tt - N_cv

    train_km = km[:N_tr]
    valid_km = km[N_tr:N_tr + N_cv]
    test_km = km[N_tr + N_cv:]

    destdir = Path(args.destdir)
    with open(destdir / "train.km", "w") as f:
        f.writelines(train_km)
    with open(destdir / "valid.km", "w") as f:
        f.writelines(valid_km)
    with open(destdir / "test.km", "w") as f:
        f.writelines(test_km)

    print(f"train: {len(train_km)}")
    print(f"valid: {len(valid_km)}")
    print(f"test: {len(test_km)}")
    print("done")