freddyaboulton HF Staff committed on
Commit
1e5834c
·
1 Parent(s): af73830
Files changed (4) hide show
  1. app.py +39 -3
  2. audio_index.py +153 -0
  3. requirements.txt +1 -1
  4. search.py +23 -0
app.py CHANGED
@@ -1,7 +1,43 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ from .audio_index import AudioEmbeddingSystem
4
+ from .search import search
5
+ import pandas as pd
6
 
7
# Download the prebuilt SQLite database and FAISS index from the Hub at import
# time, then open them with the embedding system used by the search handler.
_ASSET_REPO = "freddyaboulton/common-voice-english-audio"

db_file = hf_hub_download(repo_id=_ASSET_REPO, filename="audio_db.sqlite")
index_file = hf_hub_download(repo_id=_ASSET_REPO, filename="audio_faiss.index")

audio_embedding_system = AudioEmbeddingSystem(index_path=index_file, db_path=db_file)
15
+
16
+
17
def audio_search(audio_tuple):
    """Find the indexed clips whose speaker embedding is closest to the input.

    Args:
        audio_tuple: ``(sample_rate, samples)`` as produced by ``gr.Audio``,
            or ``None`` when the user submitted without providing audio.

    Returns:
        A ``pandas.DataFrame`` with the columns ``path``, ``audio``,
        ``sentence``, ``distance`` and ``vector_id`` (the same order as the
        output component's ``headers``/``datatype``), sorted by ascending
        distance. Empty when there is no input or no match.
    """
    import html  # local import: only needed to escape the audio URL below

    columns = ["path", "audio", "sentence", "distance", "vector_id"]
    if audio_tuple is None:
        return pd.DataFrame(columns=columns)

    sample_rate, array = audio_tuple
    # Only the first 10 seconds of the clip are used for the query.
    array = array[: int(sample_rate * 10)]

    # Pass a row dict: ``search(sample_rate, array)`` would bind the samples
    # to the ``top_k`` parameter instead of using them as the query.
    rows = audio_embedding_system.search(
        {"audio": {"array": array, "sampling_rate": sample_rate}}
    )
    if not rows:
        return pd.DataFrame(columns=columns)

    # Accept either key name for the file path so the rest of the function
    # (and the dataset lookup) can rely on "path".
    for row in rows:
        if "path" not in row and "file_path" in row:
            row["path"] = row.pop("file_path")

    # Index the original dataset rows by path once instead of rescanning the
    # list for every hit.
    originals = {}
    for item in search(rows):
        # NOTE(review): the dataset viewer API appears to wrap each record
        # under a "row" key; fall back to the item itself if it does not.
        record = item.get("row", item)
        originals[record["path"]] = record

    for row in rows:
        record = originals.get(row["path"])
        if record is None:
            continue
        row["sentence"] = record["sentence"]
        audio = record["audio"]
        # NOTE(review): audio cells appear to be a list of {"src", "type"}
        # dicts; a bare dict is accepted as well.
        if isinstance(audio, list):
            audio = audio[0] if audio else None
        if audio:
            src = html.escape(audio["src"], quote=True)
            row["audio"] = '<audio src="' + src + '" controls></audio>'

    # ``columns=`` fixes the column order (and fills columns missing from
    # unmatched rows) so it agrees with the positional datatype list.
    return pd.DataFrame(rows, columns=columns).sort_values(
        by="distance", ascending=True
    )
31
+
32
+
33
# Input: a voice clip, either uploaded or recorded from the microphone.
voice_input = gr.Audio(
    sources=["upload", "microphone"],
    label="Record or upload a clip of your voice",
)

# Output: one table row per match; the "audio" column is rendered as markdown.
matches_table = gr.Dataframe(
    datatype=["str", "markdown", "str", "number", "str"],
    headers=["path", "audio", "sentence", "distance", "vector_id"],
)

demo = gr.Interface(fn=audio_search, inputs=voice_input, outputs=matches_table)
demo.launch()
audio_index.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import faiss
4
+ import sqlite3
5
+ import torch
6
+ import librosa
7
+
8
+ import nemo.collections.asr as nemo_asr
9
+
10
# Load the pretrained speaker model once at import time and freeze it; it is
# shared by every embedding call in this module.
_SPEAKER_MODEL_CLS = nemo_asr.models.EncDecSpeakerLabelModel
speaker_model = _SPEAKER_MODEL_CLS.from_pretrained("rimelabs/rimecaster")
speaker_model.freeze()
14
+
15
+
16
+ def get_embedding(row: dict) -> torch.Tensor | None:
17
+ # Ensure audio is mono
18
+ if row["audio"]["array"].ndim > 1:
19
+ audio_array = librosa.to_mono(
20
+ row["audio"]["array"].T
21
+ ) # Transpose if shape is (samples, channels)
22
+ else:
23
+ audio_array = row["audio"]["array"]
24
+
25
+ # Resample for embedding (keep original for upload)
26
+ try:
27
+ audio_resampled = librosa.resample(
28
+ audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
29
+ )
30
+ except Exception as e:
31
+ print(f"Error resampling audio: {e}. Skipping embedding for row.")
32
+ return None # Return None if resampling fails
33
+
34
+ audio_length = audio_resampled.shape[0]
35
+ device = speaker_model.device
36
+ audio_resampled = np.array([audio_resampled]) # Add batch dim for model
37
+ audio_signal, audio_signal_len = (
38
+ torch.tensor(audio_resampled, device=device, dtype=torch.float32),
39
+ torch.tensor([audio_length], device=device),
40
+ )
41
+ _, emb = speaker_model.forward(
42
+ input_signal=audio_signal, input_signal_length=audio_signal_len
43
+ )
44
+ del audio_signal, audio_signal_len, audio_resampled # Clean up resampled audio
45
+ return emb.detach().cpu() # Return the tensor
46
+
47
+
48
def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    """Embed raw samples by wrapping them in the row layout ``get_embedding`` reads.

    Args:
        sample_rate: Sampling rate of ``audio_array``.
        audio_array: The audio samples.

    Returns:
        Whatever ``get_embedding`` returns for the wrapped row.
    """
    audio_entry = {"array": audio_array, "sampling_rate": sample_rate}
    return get_embedding({"audio": audio_entry})
51
+
52
+
53
+ class AudioEmbeddingSystem:
54
+ def __init__(
55
+ self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
56
+ ):
57
+ """
58
+ Initialize the audio embedding system.
59
+
60
+ Args:
61
+ model_name: HuggingFace model to use for embeddings
62
+ db_path: Path to SQLite database
63
+ index_path: Path to save FAISS index
64
+ vector_dim: Dimension of embedding vectors
65
+ use_quantization: Whether to use vector quantization (reduces size)
66
+ """
67
+ self.db_path = db_path
68
+ self.index_path = index_path
69
+ self.vector_dim = vector_dim
70
+
71
+ self._init_db()
72
+
73
+ if os.path.exists(index_path):
74
+ self.index = faiss.read_index(index_path)
75
+ else:
76
+ self.index = faiss.IndexFlatL2(vector_dim)
77
+
78
+ def _init_db(self):
79
+ """Initialize SQLite database with required tables"""
80
+ conn = sqlite3.connect(self.db_path)
81
+ cursor = conn.cursor()
82
+
83
+ cursor.execute("""
84
+ CREATE TABLE IF NOT EXISTS audio_files (
85
+ id INTEGER PRIMARY KEY,
86
+ file_path TEXT UNIQUE,
87
+ vector_id INTEGER
88
+ )
89
+ """)
90
+ conn.commit()
91
+ conn.close()
92
+
93
+ def extract_embedding(self, row: dict):
94
+ """Extract embedding from audio file"""
95
+ return get_embedding(row)
96
+
97
+ def add_audio(self, row):
98
+ """Add audio file to the database and index"""
99
+ embedding = self.extract_embedding(row)
100
+
101
+ embedding_normalized = embedding.reshape(1, -1).astype(np.float32)
102
+
103
+ current_index_size = self.index.ntotal
104
+ self.index.add(embedding_normalized)
105
+
106
+ conn = sqlite3.connect(self.db_path)
107
+ cursor = conn.cursor()
108
+ cursor.execute(
109
+ "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
110
+ (row["path"], current_index_size),
111
+ )
112
+ conn.commit()
113
+ conn.close()
114
+
115
+ faiss.write_index(self.index, self.index_path)
116
+
117
+ return current_index_size
118
+
119
+ def search(self, row: dict | tuple, top_k=5):
120
+ """
121
+ Search for similar audio files.
122
+ Either provide query_audio (path to audio file) or query_embedding (numpy array)
123
+ """
124
+ if isinstance(row, dict):
125
+ query_embedding = self.extract_embedding(row)
126
+ else:
127
+ query_embedding = get_embedding_from_array(row)
128
+
129
+ query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
130
+
131
+ distances, indices = self.index.search(query_embedding, top_k)
132
+
133
+ conn = sqlite3.connect(self.db_path)
134
+ cursor = conn.cursor()
135
+
136
+ results = []
137
+ for i, idx in enumerate(indices[0]):
138
+ cursor.execute(
139
+ "SELECT file_path, metadata FROM audio_files WHERE vector_id = ?",
140
+ (int(idx),),
141
+ )
142
+ row = cursor.fetchone()
143
+ if row:
144
+ results.append(
145
+ {
146
+ "file_path": row[0],
147
+ "distance": float(distances[0][i]),
148
+ "vector_id": int(idx),
149
+ }
150
+ )
151
+
152
+ conn.close()
153
+ return results
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  faiss-cpu
2
- nemo_toolkit['all']
 
1
  faiss-cpu
2
+ nemo_toolkit[all]
search.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+
4
# Only send an Authorization header when a token is configured; formatting a
# missing token would send the literal string "Bearer None".
_hf_token = os.getenv("HF_TOKEN")
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}

# Dataset, configuration and split that file paths are looked up in.
dataset = "mozilla-foundation/common_voice_17_0"
config = "en"
split = "validation"
9
+
10
+
11
def search(rows: list[dict]):
    """Fetch the dataset rows whose ``path`` matches one of the given rows.

    Args:
        rows: Dicts that each carry a ``"path"`` entry.

    Returns:
        The ``"rows"`` list of the filter endpoint's JSON response, or an
        empty list when ``rows`` is empty or the response has no such key.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
    """
    file_paths_to_find = [row["path"] for row in rows]
    if not file_paths_to_find:
        # An empty IN () clause is not a valid filter; nothing to look up.
        return []

    # Double single quotes so a path cannot terminate the string literal.
    paths_in_clause = ", ".join(
        "'" + str(path).replace("'", "''") + "'" for path in file_paths_to_find
    )
    where_clause = f'"path" IN ({paths_in_clause})'

    # Let requests URL-encode the query string; the where clause contains
    # quotes, spaces and commas that must not be pasted into the URL verbatim.
    response = requests.get(
        "https://datasets-server.huggingface.co/filter",
        params={
            "dataset": dataset,
            "config": config,
            "split": split,
            "where": where_clause,
            "offset": 0,
        },
        headers=headers,
        timeout=30,  # do not hang the UI forever on a stalled request
    )
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    data = response.json()

    return data.get("rows", [])