freddyaboulton HF Staff commited on
Commit
1412907
·
1 Parent(s): e1ef382
Files changed (2) hide show
  1. app.py +21 -4
  2. audio_index.py +3 -1
app.py CHANGED
@@ -28,22 +28,39 @@ def audio_search(audio_tuple, prompt: str):
28
  array = array.astype(np.float32) / 32768.0
29
 
30
  rows = audio_embedding_system.search((sample_rate, array))
31
- print(rows)
 
32
  orig_rows = search(rows)
33
- for row in rows:
34
  path = row["path"]
35
  for orig in orig_rows:
36
  orig_row = orig["row"]
37
- print(orig_row)
38
  if orig_row["path"] == path:
39
  row["sentence"] = orig_row["sentence"]
40
  row["audio"] = [
41
  "<audio src=" + orig_row["audio"][0]["src"] + " controls />"
42
  ]
43
- return pd.DataFrame(rows)[["path", "audio", "sentence", "distance"]].sort_values(
44
  by="distance", ascending=True
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  sample_text = gr.Textbox(
48
  label="Prompt",
49
  info="Hit Enter to get a prompt from the common voice dataset",
 
28
  array = array.astype(np.float32) / 32768.0
29
 
30
  rows = audio_embedding_system.search((sample_rate, array))
31
+ least_similar = audio_embedding_system.search((sample_rate, array), least_similar=True)
32
+ rows += least_similar
33
  orig_rows = search(rows)
34
+ for i, row in enumerate(rows):
35
  path = row["path"]
36
  for orig in orig_rows:
37
  orig_row = orig["row"]
 
38
  if orig_row["path"] == path:
39
  row["sentence"] = orig_row["sentence"]
40
  row["audio"] = [
41
  "<audio src=" + orig_row["audio"][0]["src"] + " controls />"
42
  ]
43
+ df = pd.DataFrame(rows)[["path", "audio", "sentence", "distance"]].sort_values(
44
  by="distance", ascending=True
45
  )
46
 
47
+ # Define the styling function
48
+ def style_path_column(col):
49
+ n = len(col)
50
+ # Default empty styles
51
+ styles = [''] * n
52
+ for i in range(n):
53
+ # First 5 rows: green background with opacity
54
+ if i < 5:
55
+ styles[i] = 'background-color: rgba(0, 255, 0, 0.3)'
56
+ # Last 3 rows: red background with opacity
57
+ elif i >= n - 3:
58
+ styles[i] = 'background-color: rgba(255, 0, 0, 0.3)'
59
+ return styles
60
+
61
+ # Apply the styling to the 'path' column and return the Styler object
62
+ return df.style.apply(style_path_column, subset=['path'])
63
+
64
  sample_text = gr.Textbox(
65
  label="Prompt",
66
  info="Hit Enter to get a prompt from the common voice dataset",
audio_index.py CHANGED
@@ -116,7 +116,7 @@ class AudioEmbeddingSystem:
116
 
117
  return current_index_size
118
 
119
- def search(self, row: dict | tuple, top_k=5):
120
  """
121
  Search for similar audio files.
122
  Either provide query_audio (path to audio file) or query_embedding (numpy array)
@@ -127,6 +127,8 @@ class AudioEmbeddingSystem:
127
  query_embedding = get_embedding_from_array(*row)
128
 
129
  query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
 
 
130
 
131
  distances, indices = self.index.search(query_embedding, top_k)
132
 
 
116
 
117
  return current_index_size
118
 
119
+ def search(self, row: dict | tuple, top_k=5, least_similar=False):
120
  """
121
  Search for similar audio files.
122
  Either provide query_audio (path to audio file) or query_embedding (numpy array)
 
127
  query_embedding = get_embedding_from_array(*row)
128
 
129
  query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
130
+ if least_similar:
131
+ query_embedding = -1 * query_embedding
132
 
133
  distances, indices = self.index.search(query_embedding, top_k)
134