voice-match / search.py
freddyaboulton's picture
Add Prompts
e1ef382
import requests
import os
import random
# Send an Authorization header only when HF_TOKEN is set: formatting an unset
# token into the f-string would produce the literal header value "Bearer None".
_hf_token = os.getenv("HF_TOKEN")
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}
# Dataset repository and configuration name used to build the API requests.
dataset = "mozilla-foundation/common_voice_17_0"
config = "en"
def _search(paths: list[str]) -> list[dict]:
    """Fetch the dataset rows whose ``path`` column matches one of *paths*.

    The split to query is chosen from the first path only: ``train`` when it
    starts with ``en_train``, otherwise ``validation``. All paths are sent in
    one request against that single split, so callers should pass paths that
    belong to the same split.

    Args:
        paths: Values of the ``path`` column to look up.

    Returns:
        The ``rows`` list from the JSON response, or an empty list when
        *paths* is empty or the response has no ``rows`` key.

    Raises:
        requests.HTTPError: If the server answers with a 4xx or 5xx status.
        requests.Timeout: If the server does not answer within the timeout.
    """
    if not paths:
        return []
    split = "train" if paths[0].startswith("en_train") else "validation"
    # Double any single quote (standard SQL escaping) so a path containing
    # "'" cannot terminate the string literal and corrupt the where clause.
    quoted_paths = ", ".join(
        "'" + path.replace("'", "''") + "'" for path in paths
    )
    # Let requests encode the query string; the where clause contains spaces,
    # quotes and parentheses, and a path could contain "&" or "#".
    params = {
        "dataset": dataset,
        "config": config,
        "split": split,
        "where": f'"path" IN ({quoted_paths})',
        "offset": 0,
    }
    # NOTE(review): no "length" is sent, so the server's default page size
    # applies; confirm callers never pass more paths than one page returns.
    response = requests.get(
        "https://datasets-server.huggingface.co/filter",
        params=params,
        headers=headers,
        timeout=30,  # never block indefinitely on an unresponsive server
    )
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    return response.json().get("rows", [])
def get_prompt() -> str:
    """Get a random sentence from the Common Voice dataset.

    Requests a single row of the ``train`` split at a random offset between
    0 and 100,000 (inclusive) and returns that row's ``sentence`` field.

    Returns:
        The ``sentence`` value of the fetched row.

    Raises:
        requests.HTTPError: If the server answers with a 4xx or 5xx status.
        requests.Timeout: If the server does not answer within the timeout.
        IndexError: If the response contains no rows.
    """
    offset = random.randint(0, 100_000)
    params = {
        "dataset": dataset,
        "config": config,
        "split": "train",
        "offset": offset,
        "length": 1,
    }
    response = requests.get(
        "https://datasets-server.huggingface.co/rows",
        params=params,
        headers=headers,
        timeout=30,  # never block indefinitely on an unresponsive server
    )
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    rows = response.json().get("rows", [])
    if not rows:
        # Same exception type the bare ``[0]`` lookup raised, but with a
        # message that says what actually went wrong.
        raise IndexError(f"No rows returned from the dataset at offset {offset}")
    return rows[0]["row"]["sentence"]
def search(rows: list[dict]):
    """Look up the dataset rows for the files referenced by *rows*.

    The ``"path"`` of every input row is placed in the train group when it
    starts with ``en_train`` and in the validation group otherwise. Each
    group is looked up with ``_search`` and the two results are concatenated,
    train rows first.
    """
    wanted = [entry["path"] for entry in rows]
    in_train = [p for p in wanted if p.startswith("en_train")]
    in_validation = [p for p in wanted if not p.startswith("en_train")]
    return _search(in_train) + _search(in_validation)