Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import numpy as np | |
import faiss | |
import sqlite3 | |
import torch | |
import librosa | |
import nemo.collections.asr as nemo_asr | |
# Pretrained NeMo speaker-embedding model. It is loaded once at import time and
# shared by every call to get_embedding(). freeze() disables gradient tracking
# on its parameters because the model is only used for inference in this file.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("rimelabs/rimecaster")
speaker_model.freeze()
def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for one audio row using ``speaker_model``.

    Args:
        row: Mapping with ``row["audio"]["array"]`` (the waveform as a NumPy
            array) and ``row["audio"]["sampling_rate"]`` (its sample rate).

    Returns:
        The model's embedding output, detached and converted to a NumPy array
        (not a ``torch.Tensor``). Returns ``None`` when resampling fails or the
        resampled audio is empty.
    """
    audio = row["audio"]
    waveform = audio["array"]
    if waveform.ndim > 1:
        # librosa.to_mono treats the last axis as samples; the transpose
        # assumes the input is laid out as (samples, channels) -- TODO confirm.
        waveform = librosa.to_mono(waveform.T)

    # Resample a copy for embedding; the caller's original array is kept
    # (e.g. for upload).
    try:
        resampled = librosa.resample(
            waveform, orig_sr=audio["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None  # Treat an unresamplable row as "no embedding"

    num_samples = resampled.shape[0]
    if num_samples == 0:
        # Nothing to embed; skip like a resampling failure rather than
        # handing a zero-length signal to the model.
        print("Empty audio after resampling. Skipping embedding for row.")
        return None

    device = speaker_model.device
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        audio_signal = torch.tensor(
            resampled[np.newaxis, :], device=device, dtype=torch.float32
        )  # add a batch dimension -> (1, num_samples)
        audio_signal_len = torch.tensor([num_samples], device=device)
        _, emb = speaker_model.forward(
            input_signal=audio_signal, input_signal_length=audio_signal_len
        )
    return emb.detach().cpu().numpy()
def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    """Embed a raw waveform by wrapping it in the row layout get_embedding() reads.

    Args:
        sample_rate: Sample rate of ``audio_array``.
        audio_array: The waveform samples.

    Returns:
        Whatever ``get_embedding`` returns for the wrapped row.
    """
    wrapped_row = {"audio": {"sampling_rate": sample_rate, "array": audio_array}}
    return get_embedding(wrapped_row)
class AudioEmbeddingSystem: | |
def __init__( | |
self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768 | |
): | |
""" | |
Initialize the audio embedding system. | |
Args: | |
model_name: HuggingFace model to use for embeddings | |
db_path: Path to SQLite database | |
index_path: Path to save FAISS index | |
vector_dim: Dimension of embedding vectors | |
use_quantization: Whether to use vector quantization (reduces size) | |
""" | |
self.db_path = db_path | |
self.index_path = index_path | |
self.vector_dim = vector_dim | |
self._init_db() | |
if os.path.exists(index_path): | |
self.index = faiss.read_index(index_path) | |
else: | |
self.index = faiss.IndexFlatL2(vector_dim) | |
def _init_db(self): | |
"""Initialize SQLite database with required tables""" | |
conn = sqlite3.connect(self.db_path) | |
cursor = conn.cursor() | |
cursor.execute(""" | |
CREATE TABLE IF NOT EXISTS audio_files ( | |
id INTEGER PRIMARY KEY, | |
file_path TEXT UNIQUE, | |
vector_id INTEGER | |
) | |
""") | |
conn.commit() | |
conn.close() | |
def extract_embedding(self, row: dict): | |
"""Extract embedding from audio file""" | |
return get_embedding(row) | |
def add_audio(self, row): | |
"""Add audio file to the database and index""" | |
embedding = self.extract_embedding(row) | |
embedding_normalized = embedding.reshape(1, -1).astype(np.float32) | |
current_index_size = self.index.ntotal | |
self.index.add(embedding_normalized) | |
conn = sqlite3.connect(self.db_path) | |
cursor = conn.cursor() | |
cursor.execute( | |
"INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)", | |
(row["path"], current_index_size), | |
) | |
conn.commit() | |
conn.close() | |
faiss.write_index(self.index, self.index_path) | |
return current_index_size | |
def search(self, row: dict | tuple, top_k=5): | |
""" | |
Search for similar audio files. | |
Either provide query_audio (path to audio file) or query_embedding (numpy array) | |
""" | |
if isinstance(row, dict): | |
query_embedding = self.extract_embedding(row) | |
else: | |
query_embedding = get_embedding_from_array(*row) | |
query_embedding = query_embedding.reshape(1, -1).astype(np.float32) | |
distances, indices = self.index.search(query_embedding, top_k) | |
conn = sqlite3.connect(self.db_path) | |
cursor = conn.cursor() | |
results = [] | |
for i, idx in enumerate(indices[0]): | |
cursor.execute( | |
"SELECT file_path FROM audio_files WHERE vector_id = ?", | |
(int(idx),), | |
) | |
row = cursor.fetchone() | |
if row: | |
results.append( | |
{ | |
"path": row[0], | |
"distance": float(distances[0][i]), | |
"vector_id": int(idx), | |
} | |
) | |
conn.close() | |
return results | |