freddyaboulton HF Staff commited on
Commit
448ae42
·
1 Parent(s): d2e72fa

train search

Browse files
Files changed (1) hide show
  1. search.py +31 -1
search.py CHANGED
@@ -5,11 +5,41 @@ headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
5
 
6
  dataset = "mozilla-foundation/common_voice_17_0"
7
  config = "en"
8
- split = "validation"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def search(rows: list[dict]):
12
  file_paths_to_find = [row["path"] for row in rows]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  paths_in_clause = ", ".join([f"'{path}'" for path in file_paths_to_find])
15
  where_clause = f'"path" IN ({paths_in_clause})'
 
5
 
6
  dataset = "mozilla-foundation/common_voice_17_0"
7
  config = "en"
 
8
 
9
+ def _search(paths: list[str]):
10
+ if paths[0].startswith("en_train"):
11
+ split = "train"
12
+ else:
13
+ split = "validation"
14
+
15
+
16
+ paths_in_clause = ", ".join([f"'{path}'" for path in paths])
17
+ where_clause = f'"path" IN ({paths_in_clause})'
18
+
19
+ api_url = f"https://datasets-server.huggingface.co/filter?dataset={dataset}&config={config}&split={split}&where={where_clause}&offset=0"
20
+
21
+ response = requests.get(api_url, headers=headers)
22
+ response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
23
+ data = response.json()
24
+
25
+ return data.get("rows", [])
26
 
27
  def search(rows: list[dict]):
28
  file_paths_to_find = [row["path"] for row in rows]
29
+ train_paths = []
30
+ validation_paths = []
31
+ for path in file_paths_to_find:
32
+ if path.startswith("en_train"):
33
+ train_paths.append(path)
34
+ else:
35
+ validation_paths.append(path)
36
+
37
+ train_rows = _search(train_paths)
38
+ validation_rows = _search(validation_paths)
39
+
40
+ return train_rows + validation_rows
41
+
42
+
43
 
44
  paths_in_clause = ", ".join([f"'{path}'" for path in file_paths_to_find])
45
  where_clause = f'"path" IN ({paths_in_clause})'