import os
import sqlite3
from contextlib import closing

import faiss
import librosa
import nemo.collections.asr as nemo_asr
import numpy as np
import torch

# Sample rate (Hz) the speaker model is fed; all audio is resampled to it.
EMBEDDING_SAMPLE_RATE = 16_000

# Loaded once at import time (from_pretrained may download weights) and
# frozen, since it is only used for inference here.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "rimelabs/rimecaster"
)
speaker_model.freeze()


def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for one audio record.

    Args:
        row: Mapping with an ``"audio"`` entry holding ``"array"`` (the
            waveform as a NumPy array) and ``"sampling_rate"`` (its rate in
            Hz). Arrays with more than one dimension are down-mixed to mono.

    Returns:
        The model's embedding output as a NumPy array on the CPU (computed
        with a batch of one, so presumably shaped ``(1, dim)``), or ``None``
        if the audio could not be resampled.
    """
    audio = row["audio"]
    waveform = audio["array"]
    if waveform.ndim > 1:
        # librosa.to_mono expects (channels, samples); the transpose assumes
        # the incoming array is (samples, channels) — TODO confirm upstream.
        waveform = librosa.to_mono(waveform.T)

    # Resample a new array for the model; the caller's array is left as-is.
    try:
        resampled = librosa.resample(
            waveform,
            orig_sr=audio["sampling_rate"],
            target_sr=EMBEDDING_SAMPLE_RATE,
        )
    except Exception as e:
        # Deliberate best-effort: callers receive None and decide what to do.
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None

    device = speaker_model.device
    # Add a batch dimension of 1 for the model.
    signal = torch.tensor(
        resampled[np.newaxis, :], device=device, dtype=torch.float32
    )
    signal_len = torch.tensor([resampled.shape[0]], device=device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, emb = speaker_model.forward(
            input_signal=signal, input_signal_length=signal_len
        )
    return emb.detach().cpu().numpy()


def get_embedding_from_array(
    sample_rate: int, audio_array: np.ndarray
) -> np.ndarray | None:
    """Compute a speaker embedding from a raw waveform and its sample rate.

    Thin adapter that wraps the arguments in the record layout expected by
    ``get_embedding`` and returns its result (``None`` on resampling failure).
    """
    row = {"audio": {"array": audio_array, "sampling_rate": sample_rate}}
    return get_embedding(row)


class AudioEmbeddingSystem:
    """Store speaker embeddings in a FAISS index and their file paths in SQLite.

    Each added file gets one row in the ``audio_files`` table whose
    ``vector_id`` is the position of its embedding in the FAISS index; search
    results are mapped back to file paths through that column.
    """

    def __init__(
        self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
    ):
        """Initialize the audio embedding system.

        Creates the SQLite table if needed, then loads the FAISS index from
        ``index_path`` if that file exists, otherwise creates an empty exact
        (brute-force) L2 index.

        Args:
            db_path: Path to the SQLite database.
            index_path: Path the FAISS index is loaded from and saved to.
            vector_dim: Dimension of embedding vectors; only used when a new
                index is created.
        """
        self.db_path = db_path
        self.index_path = index_path
        self.vector_dim = vector_dim

        self._init_db()

        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            self.index = faiss.IndexFlatL2(vector_dim)

    def _init_db(self):
        """Create the ``audio_files`` table if it does not already exist."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute("""
            CREATE TABLE IF NOT EXISTS audio_files (
                id INTEGER PRIMARY KEY,
                file_path TEXT UNIQUE,
                vector_id INTEGER
            )
            """)
            conn.commit()

    def extract_embedding(self, row: dict):
        """Return the embedding for ``row`` (``None`` if it cannot be computed)."""
        return get_embedding(row)

    def add_audio(self, row):
        """Embed one audio record and register it in the index and database.

        Args:
            row: Record with a ``"path"`` (stored as the UNIQUE ``file_path``)
                and an ``"audio"`` entry as accepted by ``get_embedding``.

        Returns:
            The FAISS vector id assigned to the new embedding.

        Raises:
            ValueError: If no embedding could be computed for the audio.
            sqlite3.IntegrityError: If ``row["path"]`` is already stored; the
                index is not modified in that case.
        """
        embedding = self.extract_embedding(row)
        if embedding is None:
            raise ValueError(
                f"Could not compute an embedding for {row.get('path')!r}"
            )
        vector = embedding.reshape(1, -1).astype(np.float32)

        with closing(sqlite3.connect(self.db_path)) as conn:
            # Vectors are appended, so the new vector's id is the current size.
            vector_id = self.index.ntotal
            # Insert before touching the index so a duplicate path fails
            # without leaving an orphan vector; the row is committed only after
            # the vector is added and the index persisted (closing without a
            # commit discards the insert).
            conn.execute(
                "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
                (row["path"], vector_id),
            )
            self.index.add(vector)
            faiss.write_index(self.index, self.index_path)
            conn.commit()

        return vector_id

    def search(self, row: dict | tuple, top_k=5, least_similar=False):
        """Find stored audio files ranked by embedding distance to a query.

        Args:
            row: Either a record dict accepted by ``get_embedding`` or a
                ``(sample_rate, audio_array)`` tuple.
            top_k: Maximum number of results to return.
            least_similar: If True, return the ``top_k`` stored vectors
                farthest from the query (largest distance first) instead of
                the nearest ones.

        Returns:
            A list of dicts with ``"path"``, ``"distance"`` (as reported by the
            FAISS index; squared L2 for the default ``IndexFlatL2``) and
            ``"vector_id"``. Ids without a matching database row are skipped.

        Raises:
            ValueError: If no embedding could be computed for the query.
        """
        if isinstance(row, dict):
            query_embedding = self.extract_embedding(row)
        else:
            query_embedding = get_embedding_from_array(*row)
        if query_embedding is None:
            raise ValueError("Could not compute an embedding for the query audio")
        query = query_embedding.reshape(1, -1).astype(np.float32)

        if self.index.ntotal == 0 or top_k <= 0:
            return []

        if least_similar:
            # Negating the query is only a valid "farthest" trick for
            # unit-normalized vectors; instead, rank everything and take the
            # tail so distances are measured to the real query.
            distances, indices = self.index.search(query, self.index.ntotal)
        else:
            distances, indices = self.index.search(query, top_k)

        # FAISS pads missing results with id -1; drop those.
        hits = [
            (float(dist), int(idx))
            for dist, idx in zip(distances[0], indices[0])
            if idx >= 0
        ]
        if least_similar:
            hits = hits[::-1][:top_k]

        results = []
        with closing(sqlite3.connect(self.db_path)) as conn:
            for distance, vector_id in hits:
                record = conn.execute(
                    "SELECT file_path FROM audio_files WHERE vector_id = ?",
                    (vector_id,),
                ).fetchone()
                if record:
                    results.append(
                        {
                            "path": record[0],
                            "distance": distance,
                            "vector_id": vector_id,
                        }
                    )
        return results