File size: 4,849 Bytes
1e5834c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775cbf6
1e5834c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3277a4
1e5834c
 
 
 
 
 
 
 
 
 
 
cd01a28
1e5834c
 
 
 
 
 
1d6a544
1e5834c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import numpy as np
import faiss
import sqlite3
import torch
import librosa

import nemo.collections.asr as nemo_asr

# Module-level speaker-embedding model used by get_embedding() below.
# NOTE(review): this runs at import time; `from_pretrained` presumably
# downloads/caches the "rimelabs/rimecaster" checkpoint — confirm against NeMo.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "rimelabs/rimecaster"
)
# The model is only used for inference in this file, so its weights are frozen.
speaker_model.freeze()


def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for one audio sample.

    Args:
        row: Mapping with an ``"audio"`` entry that holds ``"array"`` (the
            waveform) and ``"sampling_rate"`` (its sample rate in Hz).
            Arrays with more than one dimension are treated as
            ``(samples, channels)`` and down-mixed to mono.

    Returns:
        The embedding returned by ``speaker_model`` as a NumPy array (with
        the leading batch dimension of one that was added for the model),
        or ``None`` when the audio could not be resampled.
    """
    audio_array = np.asarray(row["audio"]["array"])

    # librosa only accepts floating-point buffers. Signed-integer PCM
    # (e.g. int16) is scaled to [-1, 1) so such input is embedded instead of
    # being rejected.
    if np.issubdtype(audio_array.dtype, np.signedinteger):
        full_scale = float(np.iinfo(audio_array.dtype).max) + 1.0
        audio_array = audio_array.astype(np.float32) / full_scale

    # Ensure audio is mono; transpose because the input is assumed to be
    # (samples, channels) while librosa.to_mono expects (channels, samples).
    if audio_array.ndim > 1:
        audio_array = librosa.to_mono(audio_array.T)

    # The embedding is computed on a 16 kHz copy of the audio.
    try:
        audio_resampled = librosa.resample(
            audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        # Best effort: report the problem and let the caller decide.
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None

    audio_length = audio_resampled.shape[0]
    device = speaker_model.device
    audio_batch = np.array([audio_resampled])  # Add batch dim for model
    audio_signal = torch.tensor(audio_batch, device=device, dtype=torch.float32)
    audio_signal_len = torch.tensor([audio_length], device=device)

    # Inference only: no autograd bookkeeping is needed for the forward pass.
    with torch.no_grad():
        _, emb = speaker_model.forward(
            input_signal=audio_signal, input_signal_length=audio_signal_len
        )
    return emb.detach().cpu().numpy()


def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    """Embed a raw waveform given as a sample rate plus sample array.

    The two arguments are packed into the ``{"audio": {...}}`` row layout that
    ``get_embedding`` reads, and its result is returned unchanged.
    """
    audio = {"array": audio_array, "sampling_rate": sample_rate}
    return get_embedding({"audio": audio})


class AudioEmbeddingSystem:
    def __init__(
        self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768
    ):
        """
        Initialize the audio embedding system.

        Args:
            model_name: HuggingFace model to use for embeddings
            db_path: Path to SQLite database
            index_path: Path to save FAISS index
            vector_dim: Dimension of embedding vectors
            use_quantization: Whether to use vector quantization (reduces size)
        """
        self.db_path = db_path
        self.index_path = index_path
        self.vector_dim = vector_dim

        self._init_db()

        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            self.index = faiss.IndexFlatL2(vector_dim)

    def _init_db(self):
        """Initialize SQLite database with required tables"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
        CREATE TABLE IF NOT EXISTS audio_files (
            id INTEGER PRIMARY KEY,
            file_path TEXT UNIQUE,
            vector_id INTEGER
        )
        """)
        conn.commit()
        conn.close()

    def extract_embedding(self, row: dict):
        """Extract embedding from audio file"""
        return get_embedding(row)

    def add_audio(self, row):
        """Add audio file to the database and index"""
        embedding = self.extract_embedding(row)

        embedding_normalized = embedding.reshape(1, -1).astype(np.float32)

        current_index_size = self.index.ntotal
        self.index.add(embedding_normalized)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)",
            (row["path"], current_index_size),
        )
        conn.commit()
        conn.close()

        faiss.write_index(self.index, self.index_path)

        return current_index_size

    def search(self, row: dict | tuple, top_k=5):
        """
        Search for similar audio files.
        Either provide query_audio (path to audio file) or query_embedding (numpy array)
        """
        if isinstance(row, dict):
            query_embedding = self.extract_embedding(row)
        else:
            query_embedding = get_embedding_from_array(*row)

        query_embedding = query_embedding.reshape(1, -1).astype(np.float32)

        distances, indices = self.index.search(query_embedding, top_k)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        results = []
        for i, idx in enumerate(indices[0]):
            cursor.execute(
                "SELECT file_path FROM audio_files WHERE vector_id = ?",
                (int(idx),),
            )
            row = cursor.fetchone()
            if row:
                results.append(
                    {
                        "path": row[0],
                        "distance": float(distances[0][i]),
                        "vector_id": int(idx),
                    }
                )

        conn.close()
        return results