dejanseo committed on
Commit
6677389
·
verified ·
1 Parent(s): 2034a64

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +271 -0
  2. new.tflite +3 -0
  3. old.tflite +3 -0
  4. requirements.txt +6 -0
  5. sentencepiece.model +3 -0
  6. sentences.txt +10 -0
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ import pandas as pd
6
+ import os
7
+ import time
8
+ import sentencepiece as spm
9
+
10
# Page-level settings: browser tab title and the wide page layout.
st.set_page_config(
    page_title="Embedding Model Comparison",
    layout="wide",
)
12
+
13
@st.cache_resource
def load_tokenizer(tokenizer_path="sentencepiece.model"):
    """Load the SentencePiece model stored at ``tokenizer_path``.

    Args:
        tokenizer_path: Path to the serialized SentencePiece model file.

    Returns:
        A ``SentencePieceProcessor`` with the model loaded, or ``None`` when
        the path does not exist (in which case the problem is also shown on
        the page through ``st.error``).
    """
    if os.path.exists(tokenizer_path):
        processor = spm.SentencePieceProcessor()
        processor.load(tokenizer_path)
        return processor

    st.error(f"Tokenizer file not found: {tokenizer_path}")
    return None
23
+
24
def load_model(model_path):
    """Create a ready-to-run TFLite interpreter for ``model_path``.

    Args:
        model_path: Path to a ``.tflite`` model file.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors have been allocated, or
        ``None`` when the path does not exist (the problem is also shown on
        the page through ``st.error``).

    NOTE(review): this loader carries no caching decorator, so the model file
    is read again each time the function is called.
    """
    if os.path.exists(model_path):
        tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
        tflite_interpreter.allocate_tensors()
        return tflite_interpreter

    st.error(f"Model file not found: {model_path}")
    return None
33
+
34
def get_embedding(text, interpreter, tokenizer):
    """Embed ``text`` with a TFLite interpreter and time the inference call.

    Args:
        text: The string to embed.
        interpreter: An allocated TFLite interpreter, or ``None``.
        tokenizer: An object exposing ``encode(text, out_type=int)`` that
            returns a list of token ids, or ``None``.

    Returns:
        A ``(embedding, seconds)`` tuple. ``embedding`` is the interpreter's
        first output tensor and ``seconds`` is the duration of ``invoke()``
        only (tokenization and tensor copies are excluded). Returns
        ``(None, 0)`` when either the interpreter or the tokenizer is missing.
    """
    if interpreter is None or tokenizer is None:
        return None, 0

    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]

    # The sequence length is taken from the model's declared input shape;
    # 64 is the fallback when the input has no second dimension.
    input_shape = input_info['shape']
    max_seq_length = int(input_shape[1]) if len(input_shape) > 1 else 64

    # Truncate to the model length, then right-pad with zeros up to it.
    # NOTE(review): id 0 is used as the padding id - confirm this matches the
    # id the models were trained with.
    token_ids = tokenizer.encode(text, out_type=int)[:max_seq_length]
    token_ids = token_ids + [0] * (max_seq_length - len(token_ids))

    # Feed the dtype the model declares instead of assuming int32, so models
    # exported with a different integer input type are accepted as well.
    input_dtype = input_info.get('dtype', np.int32)
    batch = np.array([token_ids], dtype=input_dtype)
    interpreter.set_tensor(input_info['index'], batch)

    # perf_counter is monotonic and high resolution, so very fast invocations
    # are not reported as exactly zero seconds on coarse wall clocks.
    started = time.perf_counter()
    interpreter.invoke()
    inference_time = time.perf_counter() - started

    embedding = interpreter.get_tensor(output_info['index'])
    return embedding, inference_time
71
+
72
def load_sentences(file_path):
    """Read the candidate sentences, one per line, from ``file_path``.

    Args:
        file_path: Path to a text file holding one sentence per line.

    Returns:
        The lines of the file with surrounding whitespace stripped and blank
        lines dropped. When ``file_path`` does not exist, a built-in list of
        ten sample sentences is returned instead.
    """
    if not os.path.exists(file_path):
        return [
            "Hello world",
            "This is a test",
            "Embedding models are useful",
            "TensorFlow Lite is great for mobile applications",
            "Streamlit makes it easy to create web apps",
            "Python is a popular programming language",
            "Machine learning is an exciting field",
            "Natural language processing helps computers understand human language",
            "Semantic search finds meaning, not just keywords",
            "Quantization reduces model size with minimal accuracy loss",
        ]

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]
88
+
89
def find_similar_sentences(query_embedding, sentence_embeddings, sentences):
    """Rank ``sentences`` by cosine similarity to ``query_embedding``.

    Args:
        query_embedding: Query embedding passed straight to
            ``cosine_similarity`` as its first argument, or ``None``.
        sentence_embeddings: Sequence of sentence embeddings, positionally
            aligned with ``sentences``.
        sentences: The sentence texts.

    Returns:
        A list of ``{"sentence": ..., "similarity": ...}`` dicts ordered from
        highest to lowest similarity. An empty list when there is no query
        embedding or no sentence embeddings.
    """
    has_query = query_embedding is not None
    if not has_query or len(sentence_embeddings) == 0:
        return []

    # Only the first row of the similarity matrix is used.
    scores = cosine_similarity(query_embedding, sentence_embeddings)[0]

    # argsort is ascending, so reverse it to get the best match first.
    ranking = np.argsort(scores)[::-1]

    return [
        {"sentence": sentences[position], "similarity": scores[position]}
        for position in ranking
    ]
109
+
110
def _results_frame(results):
    """Turn ranked similarity results into a display table with a 1-based rank."""
    return pd.DataFrame([
        {
            "Sentence": result["sentence"],
            "Similarity": f"{result['similarity']:.4f}",
            "Rank": rank,
        }
        for rank, result in enumerate(results, start=1)
    ])


def main():
    """Render the Streamlit page that compares two TFLite embedding models.

    The page reads the model, tokenizer and sentence-file paths from the
    sidebar, embeds every sentence with both models, and for a typed query
    shows each model's ranked matches, the time each model took to embed the
    query, and a bar chart of the first 20 dimensions of each query embedding.

    Sentence embeddings are kept in ``st.session_state`` and are recomputed
    when the button is pressed or when the paths / sentences they were
    computed from have changed.
    """
    st.title("Embedding Model Comparison")

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")
        old_model_path = st.text_input("Old Model Path", "old.tflite")
        new_model_path = st.text_input("New Model Path", "new.tflite")
        sentences_path = st.text_input("Sentences File Path", "sentences.txt")
        tokenizer_path = st.text_input("Tokenizer Path", "sentencepiece.model")

    # Nothing below works without a tokenizer, so stop early when it is missing.
    tokenizer = load_tokenizer(tokenizer_path)
    if tokenizer is None:
        st.sidebar.error("Failed to load tokenizer")
        return
    st.sidebar.success("Tokenizer loaded successfully")
    st.sidebar.write(f"Vocabulary size: {tokenizer.get_piece_size()}")

    # Load the models
    st.header("Models")
    old_column, new_column = st.columns(2)

    with old_column:
        st.subheader("Old Model")
        old_model = load_model(old_model_path)
        if old_model is not None:
            st.success("Old model loaded successfully")
            old_input_details = old_model.get_input_details()
            old_output_details = old_model.get_output_details()
            st.write(f"Input shape: {old_input_details[0]['shape']}")
            st.write(f"Output shape: {old_output_details[0]['shape']}")

    with new_column:
        st.subheader("New Model")
        new_model = load_model(new_model_path)
        if new_model is not None:
            st.success("New model loaded successfully")
            new_input_details = new_model.get_input_details()
            new_output_details = new_model.get_output_details()
            st.write(f"Input shape: {new_input_details[0]['shape']}")
            st.write(f"Output shape: {new_output_details[0]['shape']}")

    models_ready = old_model is not None and new_model is not None

    # Load sentences
    sentences = load_sentences(sentences_path)
    st.header("Sentences")
    st.write(f"Loaded {len(sentences)} sentences")
    if st.checkbox("Show loaded sentences"):
        st.write(sentences[:10])
        if len(sentences) > 10:
            st.write("...")

    # Cached embeddings are only valid for the exact inputs they were computed
    # from. Remember those inputs so that editing a path in the sidebar (or the
    # sentence file itself) does not pair stale embeddings with new sentences.
    signature = (old_model_path, new_model_path, tokenizer_path, tuple(sentences))
    recompute_clicked = st.button("Recompute Embeddings")
    cache_is_stale = (
        'old_sentence_embeddings' not in st.session_state
        or st.session_state.get('embedding_signature') != signature
    )

    if cache_is_stale or recompute_clicked:
        old_embeddings = []
        new_embeddings = []

        if models_ready:
            progress_bar = st.progress(0)
            st.write("Computing sentence embeddings...")

            for i, sentence in enumerate(sentences):
                # Refresh the bar every 10th sentence only.
                if i % 10 == 0:
                    progress_bar.progress(i / len(sentences))

                old_embedding, _ = get_embedding(sentence, old_model, tokenizer)
                new_embedding, _ = get_embedding(sentence, new_model, tokenizer)

                # Keep the first row of each output tensor.
                if old_embedding is not None:
                    old_embeddings.append(old_embedding[0])

                if new_embedding is not None:
                    new_embeddings.append(new_embedding[0])

            progress_bar.progress(1.0)
            st.write("Embeddings computed!")

        st.session_state.old_sentence_embeddings = old_embeddings
        st.session_state.new_sentence_embeddings = new_embeddings
        # Only mark the cache as valid when embeddings were actually computed,
        # so a model that becomes available later triggers a computation.
        st.session_state.embedding_signature = signature if models_ready else None

    # Search interface
    st.header("Search")
    query = st.text_input("Enter a search query")

    if not (query and models_ready):
        return

    # Display tokenization for the query (for debugging)
    with st.expander("View tokenization"):
        tokens = tokenizer.encode(query, out_type=int)
        pieces = tokenizer.encode(query, out_type=str)
        st.write("Token IDs:", tokens)
        st.write("Token pieces:", pieces)

    # Get query embeddings
    old_query_embedding, old_time = get_embedding(query, old_model, tokenizer)
    new_query_embedding, new_time = get_embedding(query, new_model, tokenizer)

    # Find similar sentences
    old_results = find_similar_sentences(
        old_query_embedding,
        st.session_state.old_sentence_embeddings,
        sentences,
    )
    new_results = find_similar_sentences(
        new_query_embedding,
        st.session_state.new_sentence_embeddings,
        sentences,
    )

    # Display results in two columns
    st.subheader("Search Results")
    old_column, new_column = st.columns(2)

    with old_column:
        st.markdown("### Old Model Results")
        st.dataframe(_results_frame(old_results), use_container_width=True)

    with new_column:
        st.markdown("### New Model Results")
        st.dataframe(_results_frame(new_results), use_container_width=True)

    # Show timing information
    st.subheader("Inference Time")
    st.write(f"Old model: {old_time * 1000:.2f} ms")
    st.write(f"New model: {new_time * 1000:.2f} ms")
    # A measured time of zero would otherwise raise ZeroDivisionError.
    if new_time > 0:
        st.write(f"Speed improvement: {old_time / new_time:.2f}x")
    else:
        st.write("Speed improvement: n/a (new model time was below timer resolution)")

    # Show embedding visualizations
    st.subheader("Embedding Visualizations")
    old_column, new_column = st.columns(2)

    with old_column:
        st.write("Old Model Embedding (first 20 dimensions)")
        st.bar_chart(pd.DataFrame({
            'value': old_query_embedding[0][:20]
        }))

    with new_column:
        st.write("New Model Embedding (first 20 dimensions)")
        st.bar_chart(pd.DataFrame({
            'value': new_query_embedding[0][:20]
        }))


if __name__ == "__main__":
    main()
new.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b2515d1bd7ea23e80d38772a0412fbe7fc4c07e6a2cb4d70ff4e4ab0e091071
3
+ size 36841784
old.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f916e761c1bc9da3a714c55793b8f1fa347365bff2191dfd27f2589380456cc8
3
+ size 85885944
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ tensorflow
3
+ numpy
4
+ scikit-learn
5
+ pandas
6
+ sentencepiece
sentencepiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:128a5a86dce5ebcbb5dfe3282cfff769b5eea9708c3cd0fed49add3f7f7f1802
3
+ size 794334
sentences.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ I need to renew my passport and visa before my international trip next month.
2
+ The immigration officer carefully examined all my travel documents at the border checkpoint.
3
+ My suitcase contains clothes for both warm and cold weather.
4
+ Remember to make photocopies of your passport, visa, and travel insurance policy.
5
+ The hotel requires a valid ID at check-in for all guests.
6
+ My favorite travel destination is a small coastal town in Italy.
7
+ The airline sent me an email with my boarding pass and flight itinerary.
8
+ I enjoy trying local cuisine whenever I visit a new country.
9
+ Public transportation is an affordable way to get around in most European cities.
10
+ Don't forget to bring a pen to fill out customs declaration forms on the plane.