voice-match / audio_index.py
freddyaboulton's picture
try
1412907
import os
import numpy as np
import faiss
import sqlite3
import torch
import librosa
import nemo.collections.asr as nemo_asr
# Load the pretrained speaker model once at import time so every embedding
# call in this module shares the same instance.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("rimelabs/rimecaster")
# The model is only used for inference here, so freeze it right after loading.
speaker_model.freeze()
def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for one audio record.

    Args:
        row: Mapping with an ``"audio"`` entry that holds ``"array"`` (the
            waveform) and ``"sampling_rate"`` (its sample rate in Hz).
            Multi-dimensional arrays are assumed to be laid out as
            (samples, channels) and are down-mixed to mono.

    Returns:
        The embedding returned by ``speaker_model`` converted to a NumPy
        array on the CPU (presumably shaped ``(1, embedding_dim)`` because a
        single-item batch is fed to the model), or ``None`` when resampling
        fails.
    """
    audio_array = np.asarray(row["audio"]["array"])

    # librosa rejects non floating-point audio, so signed-integer PCM
    # (e.g. int16 samples) is scaled into [-1, 1) instead of being dropped.
    if np.issubdtype(audio_array.dtype, np.signedinteger):
        full_scale = -float(np.iinfo(audio_array.dtype).min)
        audio_array = audio_array.astype(np.float32) / full_scale

    # Ensure audio is mono; the transpose turns (samples, channels) into the
    # (channels, samples) layout that librosa.to_mono expects.
    if audio_array.ndim > 1:
        audio_array = librosa.to_mono(audio_array.T)

    # Resample to the 16 kHz rate the embedding is computed at.
    try:
        audio_resampled = librosa.resample(
            audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        # Best effort: report the problem and let the caller decide what to do.
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None

    audio_length = audio_resampled.shape[0]
    device = speaker_model.device

    # Add a leading batch dimension: the model receives a batch of one signal
    # together with that signal's length in samples.
    audio_signal = torch.tensor(
        audio_resampled[np.newaxis, :], device=device, dtype=torch.float32
    )
    audio_signal_len = torch.tensor([audio_length], device=device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        _, emb = speaker_model.forward(
            input_signal=audio_signal, input_signal_length=audio_signal_len
        )

    return emb.detach().cpu().numpy()
def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    """Compute an embedding from a raw ``(sample_rate, audio_array)`` pair.

    Wraps the two values in the nested ``{"audio": {...}}`` record that
    ``get_embedding`` reads and returns whatever ``get_embedding`` returns.
    """
    audio_record = {"sampling_rate": sample_rate, "array": audio_array}
    return get_embedding({"audio": audio_record})
class AudioEmbeddingSystem:
def __init__(
self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
):
"""
Initialize the audio embedding system.
Args:
model_name: HuggingFace model to use for embeddings
db_path: Path to SQLite database
index_path: Path to save FAISS index
vector_dim: Dimension of embedding vectors
use_quantization: Whether to use vector quantization (reduces size)
"""
self.db_path = db_path
self.index_path = index_path
self.vector_dim = vector_dim
self._init_db()
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
else:
self.index = faiss.IndexFlatL2(vector_dim)
def _init_db(self):
"""Initialize SQLite database with required tables"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS audio_files (
id INTEGER PRIMARY KEY,
file_path TEXT UNIQUE,
vector_id INTEGER
)
""")
conn.commit()
conn.close()
def extract_embedding(self, row: dict):
"""Extract embedding from audio file"""
return get_embedding(row)
def add_audio(self, row):
"""Add audio file to the database and index"""
embedding = self.extract_embedding(row)
embedding_normalized = embedding.reshape(1, -1).astype(np.float32)
current_index_size = self.index.ntotal
self.index.add(embedding_normalized)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
(row["path"], current_index_size),
)
conn.commit()
conn.close()
faiss.write_index(self.index, self.index_path)
return current_index_size
def search(self, row: dict | tuple, top_k=5, least_similar=False):
"""
Search for similar audio files.
Either provide query_audio (path to audio file) or query_embedding (numpy array)
"""
if isinstance(row, dict):
query_embedding = self.extract_embedding(row)
else:
query_embedding = get_embedding_from_array(*row)
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
if least_similar:
query_embedding = -1 * query_embedding
distances, indices = self.index.search(query_embedding, top_k)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
results = []
for i, idx in enumerate(indices[0]):
cursor.execute(
"SELECT file_path FROM audio_files WHERE vector_id = ?",
(int(idx),),
)
row = cursor.fetchone()
if row:
results.append(
{
"path": row[0],
"distance": float(distances[0][i]),
"vector_id": int(idx),
}
)
conn.close()
return results