import os
import numpy as np
import faiss
import sqlite3
import torch
import librosa
import nemo.collections.asr as nemo_asr
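# Load the pretrained rimelabs/rimecaster speaker-embedding model once at import
# time and freeze it, since it is only used for inference here.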
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "rimelabs/rimecaster"
)
speaker_model.freeze()
def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for a single dataset row.

    The row is expected to contain row["audio"]["array"] and
    row["audio"]["sampling_rate"]. Returns the embedding as a NumPy array,
    or None if resampling fails.
    """
    # Ensure audio is mono
    if row["audio"]["array"].ndim > 1:
        audio_array = librosa.to_mono(
            row["audio"]["array"].T
        )  # Transpose if shape is (samples, channels)
    else:
        audio_array = row["audio"]["array"]
    # Resample for embedding (keep original for upload)
    try:
        audio_resampled = librosa.resample(
            audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None  # Return None if resampling fails
    audio_length = audio_resampled.shape[0]
    device = speaker_model.device
    audio_resampled = np.array([audio_resampled])  # Add batch dim for model
    audio_signal, audio_signal_len = (
        torch.tensor(audio_resampled, device=device, dtype=torch.float32),
        torch.tensor([audio_length], device=device),
    )
    _, emb = speaker_model.forward(
        input_signal=audio_signal, input_signal_length=audio_signal_len
    )
    del audio_signal, audio_signal_len, audio_resampled  # Clean up resampled audio
    return emb.detach().cpu().numpy()  # Return the embedding as a NumPy array
def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray) -> np.ndarray | None:
    """Convenience wrapper: build a row dict from a raw array and embed it."""
    row = {"audio": {"array": audio_array, "sampling_rate": sample_rate}}
    return get_embedding(row)
class AudioEmbeddingSystem:
    def __init__(
        self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
    ):
        """
        Initialize the audio embedding system.

        Args:
            db_path: Path to SQLite database
            index_path: Path to save FAISS index
            vector_dim: Dimension of embedding vectors
        """
        self.db_path = db_path
        self.index_path = index_path
        self.vector_dim = vector_dim
        self._init_db()
        # Load an existing FAISS index from disk, or start a new flat L2 index.
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            self.index = faiss.IndexFlatL2(vector_dim)
    def _init_db(self):
        """Initialize SQLite database with required tables"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS audio_files (
                id INTEGER PRIMARY KEY,
                file_path TEXT UNIQUE,
                vector_id INTEGER
            )
        """)
        conn.commit()
        conn.close()

    def extract_embedding(self, row: dict):
        """Extract embedding from audio file"""
        return get_embedding(row)
    def add_audio(self, row):
        """Add audio file to the database and index"""
        embedding = self.extract_embedding(row)
        if embedding is None:
            # Embedding failed (e.g. resampling error); skip indexing this row.
            return None
        # Reshape to a (1, dim) float32 vector for FAISS; no normalization is
        # applied, since IndexFlatL2 works on raw L2 distances.
        embedding_vector = embedding.reshape(1, -1).astype(np.float32)
        current_index_size = self.index.ntotal
        self.index.add(embedding_vector)
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
            (row["path"], current_index_size),
        )
        conn.commit()
        conn.close()
        faiss.write_index(self.index, self.index_path)
        return current_index_size
    def search(self, row: dict | tuple, top_k=5):
        """
        Search for similar audio files.

        Either provide a dataset-style row dict (with an "audio" entry) or a
        (sample_rate, audio_array) tuple.
        """
        if isinstance(row, dict):
            query_embedding = self.extract_embedding(row)
        else:
            query_embedding = get_embedding_from_array(*row)
        query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
        distances, indices = self.index.search(query_embedding, top_k)
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        results = []
        for i, idx in enumerate(indices[0]):
            cursor.execute(
                "SELECT file_path FROM audio_files WHERE vector_id = ?",
                (int(idx),),
            )
            db_row = cursor.fetchone()  # avoid shadowing the `row` argument
            if db_row:
                results.append(
                    {
                        "path": db_row[0],
                        "distance": float(distances[0][i]),
                        "vector_id": int(idx),
                    }
                )
        conn.close()
        return results
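# --- Usage sketch (illustrative addition, not part of the original Space) ---
# A minimal, hedged example of the indexing/search flow above. It assumes the
# rimecaster checkpoint downloads successfully and that its embedding size
# matches the default vector_dim=768. The sine-wave "recording" and the
# "example/fake_tone.wav" path are synthetic placeholders, not real data.
if __name__ == "__main__":
    sr = 44_100
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    fake_audio = (0.1 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float32)

    system = AudioEmbeddingSystem()
    demo_row = {
        "path": "example/fake_tone.wav",  # placeholder identifier stored in SQLite
        "audio": {"array": fake_audio, "sampling_rate": sr},
    }
    vector_id = system.add_audio(demo_row)
    print(f"Stored embedding with vector_id={vector_id}")

    # Query with the same signal as a (sample_rate, array) tuple; it should be
    # returned as its own nearest neighbor.
    print(system.search((sr, fake_audio), top_k=1))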