Commit 1e5834c · Parent: af73830 · commit

Files changed:
- app.py (+39, -3)
- audio_index.py (+153, -0)
- requirements.txt (+1, -1)
- search.py (+23, -0)
app.py
CHANGED

```diff
@@ -1,7 +1,43 @@
 import gradio as gr
+from huggingface_hub import hf_hub_download
+from audio_index import AudioEmbeddingSystem
+from search import search
+import pandas as pd
 
-…
-…
+db_file = hf_hub_download(
+    repo_id="freddyaboulton/common-voice-english-audio", filename="audio_db.sqlite"
+)
+index_file = hf_hub_download(
+    repo_id="freddyaboulton/common-voice-english-audio", filename="audio_faiss.index"
+)
 
-…
+audio_embedding_system = AudioEmbeddingSystem(db_path=db_file, index_path=index_file)
+
+
+def audio_search(audio_tuple):
+    sample_rate, array = audio_tuple
+    # Cap the query at the first 10 seconds of audio
+    array = array[: int(sample_rate * 10)]
+    rows = audio_embedding_system.search((sample_rate, array))
+    orig_rows = search(rows)
+    for row in rows:
+        path = row["path"]
+        for orig_row in orig_rows:
+            if orig_row["path"] == path:
+                row["sentence"] = orig_row["sentence"]
+                # datasets-server returns audio cells as a list of sources
+                src = orig_row["audio"][0]["src"]
+                row["audio"] = f'<audio src="{src}" controls />'
+    return pd.DataFrame(rows).sort_values(by="distance", ascending=True)
+
+
+demo = gr.Interface(
+    fn=audio_search,
+    inputs=gr.Audio(
+        label="Record or upload a clip of your voice", sources=["upload", "microphone"]
+    ),
+    outputs=gr.Dataframe(
+        headers=["path", "audio", "sentence", "distance", "vector_id"],
+        datatype=["str", "markdown", "str", "number", "str"],
+    ),
+)
 demo.launch()
```
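For reference, this is the shape of a record as `audio_search` assembles it: `AudioEmbeddingSystem.search` supplies `path`, `distance`, and `vector_id`, and the datasets-server lookup adds the transcript plus a playable `<audio>` tag. All values below are made up:

```python
# Illustrative only: one enriched row, as built inside audio_search
row = {
    "path": "common_voice_en_12345.mp3",  # hypothetical Common Voice clip name
    "distance": 0.42,                     # L2 distance reported by FAISS
    "vector_id": 17,                      # position of the match in the index
    "sentence": "The quick brown fox jumps over the lazy dog.",
    "audio": '<audio src="https://.../audio.mp3" controls />',  # markdown column
}
```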
audio_index.py
ADDED

@@ -0,0 +1,153 @@

```python
import os
import numpy as np
import faiss
import sqlite3
import torch
import librosa

import nemo.collections.asr as nemo_asr

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "rimelabs/rimecaster"
)
speaker_model.freeze()


def get_embedding(row: dict) -> torch.Tensor | None:
    # Ensure audio is mono
    if row["audio"]["array"].ndim > 1:
        audio_array = librosa.to_mono(
            row["audio"]["array"].T
        )  # Transpose if shape is (samples, channels)
    else:
        audio_array = row["audio"]["array"]

    # Resample for embedding (keep original for upload)
    try:
        audio_resampled = librosa.resample(
            audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None  # Return None if resampling fails

    audio_length = audio_resampled.shape[0]
    device = speaker_model.device
    audio_resampled = np.array([audio_resampled])  # Add batch dim for model
    audio_signal, audio_signal_len = (
        torch.tensor(audio_resampled, device=device, dtype=torch.float32),
        torch.tensor([audio_length], device=device),
    )
    _, emb = speaker_model.forward(
        input_signal=audio_signal, input_signal_length=audio_signal_len
    )
    del audio_signal, audio_signal_len, audio_resampled  # Clean up resampled audio
    return emb.detach().cpu()  # Return the tensor


def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    row = {"audio": {"array": audio_array, "sampling_rate": sample_rate}}
    return get_embedding(row)


class AudioEmbeddingSystem:
    def __init__(
        self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
    ):
        """
        Initialize the audio embedding system.

        Args:
            db_path: Path to SQLite database
            index_path: Path to save FAISS index
            vector_dim: Dimension of embedding vectors
        """
        self.db_path = db_path
        self.index_path = index_path
        self.vector_dim = vector_dim

        self._init_db()

        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            self.index = faiss.IndexFlatL2(vector_dim)

    def _init_db(self):
        """Initialize SQLite database with required tables"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS audio_files (
                id INTEGER PRIMARY KEY,
                file_path TEXT UNIQUE,
                vector_id INTEGER
            )
        """)
        conn.commit()
        conn.close()

    def extract_embedding(self, row: dict):
        """Extract embedding from an audio row"""
        return get_embedding(row)

    def add_audio(self, row):
        """Add audio file to the database and index"""
        embedding = self.extract_embedding(row)
        if embedding is None:
            return None  # Resampling failed; nothing was indexed

        # get_embedding returns a torch.Tensor; convert before handing to FAISS
        embedding_normalized = embedding.numpy().reshape(1, -1).astype(np.float32)

        current_index_size = self.index.ntotal
        self.index.add(embedding_normalized)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
            (row["path"], current_index_size),
        )
        conn.commit()
        conn.close()

        faiss.write_index(self.index, self.index_path)

        return current_index_size

    def search(self, row: dict | tuple, top_k=5):
        """
        Search for similar audio files.

        Accepts either a dataset-style row dict ({"audio": {...}}) or a
        (sample_rate, audio_array) tuple as produced by gr.Audio.
        """
        if isinstance(row, dict):
            query_embedding = self.extract_embedding(row)
        else:
            sample_rate, audio_array = row
            query_embedding = get_embedding_from_array(sample_rate, audio_array)
        if query_embedding is None:
            return []

        query_embedding = query_embedding.numpy().reshape(1, -1).astype(np.float32)

        distances, indices = self.index.search(query_embedding, top_k)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        results = []
        for i, idx in enumerate(indices[0]):
            cursor.execute(
                "SELECT file_path FROM audio_files WHERE vector_id = ?",
                (int(idx),),
            )
            db_row = cursor.fetchone()
            if db_row:
                results.append(
                    {
                        # keyed "path" so app.py can match against search.py rows
                        "path": db_row[0],
                        "distance": float(distances[0][i]),
                        "vector_id": int(idx),
                    }
                )

        conn.close()
        return results
```
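The Space downloads `audio_db.sqlite` and `audio_faiss.index` ready-made from the Hub, so this commit never shows the indexing pass. A minimal sketch of how the artifacts could be produced offline, assuming the Common Voice dataset named in search.py and the `datasets` streaming API (error handling elided):

```python
# Hypothetical offline indexing script (not part of this commit)
from datasets import load_dataset

from audio_index import AudioEmbeddingSystem

system = AudioEmbeddingSystem(vector_dim=768)

ds = load_dataset(
    "mozilla-foundation/common_voice_17_0", "en", split="validation", streaming=True
)
for row in ds:
    # Each row carries "path" plus "audio": {"array", "sampling_rate"},
    # the exact shape get_embedding expects
    system.add_audio(row)

# The resulting audio_db.sqlite / audio_faiss.index would then be uploaded
# to freddyaboulton/common-voice-english-audio for app.py to download.
```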
requirements.txt
CHANGED

```diff
@@ -1,2 +1,2 @@
 faiss-cpu
-nemo_toolkit[
+nemo_toolkit[all]
```
search.py
ADDED

@@ -0,0 +1,23 @@

```python
import requests
import os

headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

dataset = "mozilla-foundation/common_voice_17_0"
config = "en"
split = "validation"


def search(rows: list[dict]):
    file_paths_to_find = [row["path"] for row in rows]

    # Note: a path containing a single quote would break this clause
    paths_in_clause = ", ".join([f"'{path}'" for path in file_paths_to_find])
    where_clause = f'"path" IN ({paths_in_clause})'

    # Pass the where clause via params so requests URL-encodes it
    api_url = "https://datasets-server.huggingface.co/filter"
    params = {
        "dataset": dataset,
        "config": config,
        "split": split,
        "where": where_clause,
        "offset": 0,
    }

    response = requests.get(api_url, params=params, headers=headers)
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    data = response.json()

    # The API wraps each record as {"row_idx": ..., "row": {...}, ...};
    # unwrap so callers can read "path", "sentence", etc. directly
    return [item["row"] for item in data.get("rows", [])]
```
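For context, a sketch of the `/filter` payload this code unwraps. Field values are invented, and the list-of-sources shape of the `audio` cell is an assumption about the datasets-server response rather than something this commit pins down:

```python
# Illustrative /filter response (values invented; extra top-level keys elided)
example_response = {
    "rows": [
        {
            "row_idx": 101,
            "row": {
                "path": "common_voice_en_12345.mp3",
                "sentence": "The quick brown fox jumps over the lazy dog.",
                # assumed: audio cells arrive as a list of playable sources
                "audio": [
                    {
                        "src": "https://datasets-server.huggingface.co/.../audio.mp3",
                        "type": "audio/mpeg",
                    }
                ],
            },
            "truncated_cells": [],
        }
    ],
}
```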