Upload 6 files
Browse files- app.py +271 -0
- new.tflite +3 -0
- old.tflite +3 -0
- requirements.txt +6 -0
- sentencepiece.model +3 -0
- sentences.txt +10 -0
app.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
5 |
+
import pandas as pd
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import sentencepiece as spm
|
9 |
+
|
10 |
+
# Set page title
|
11 |
+
st.set_page_config(page_title="Embedding Model Comparison", layout="wide")
|
12 |
+
|
13 |
+
# Function to load the SentencePiece tokenizer
|
14 |
+
@st.cache_resource
|
15 |
+
def load_tokenizer(tokenizer_path="sentencepiece.model"):
|
16 |
+
if not os.path.exists(tokenizer_path):
|
17 |
+
st.error(f"Tokenizer file not found: {tokenizer_path}")
|
18 |
+
return None
|
19 |
+
|
20 |
+
sp = spm.SentencePieceProcessor()
|
21 |
+
sp.load(tokenizer_path)
|
22 |
+
return sp
|
23 |
+
|
24 |
+
# Function to load a TFLite model
|
25 |
+
def load_model(model_path):
|
26 |
+
if not os.path.exists(model_path):
|
27 |
+
st.error(f"Model file not found: {model_path}")
|
28 |
+
return None
|
29 |
+
|
30 |
+
interpreter = tf.lite.Interpreter(model_path=model_path)
|
31 |
+
interpreter.allocate_tensors()
|
32 |
+
return interpreter
|
33 |
+
|
34 |
+
# Function to get embeddings from a TFLite model
|
35 |
+
def get_embedding(text, interpreter, tokenizer):
|
36 |
+
if interpreter is None or tokenizer is None:
|
37 |
+
return None, 0
|
38 |
+
|
39 |
+
# Get input and output details
|
40 |
+
input_details = interpreter.get_input_details()
|
41 |
+
output_details = interpreter.get_output_details()
|
42 |
+
|
43 |
+
# Get the expected input shape
|
44 |
+
input_shape = input_details[0]['shape']
|
45 |
+
max_seq_length = input_shape[1] if len(input_shape) > 1 else 64
|
46 |
+
|
47 |
+
# Properly tokenize the text using SentencePiece
|
48 |
+
tokens = tokenizer.encode(text, out_type=int)
|
49 |
+
|
50 |
+
# Handle padding/truncation
|
51 |
+
if len(tokens) > max_seq_length:
|
52 |
+
tokens = tokens[:max_seq_length] # Truncate
|
53 |
+
else:
|
54 |
+
tokens = tokens + [0] * (max_seq_length - len(tokens)) # Pad
|
55 |
+
|
56 |
+
# Prepare input tensor with proper shape
|
57 |
+
token_ids = np.array([tokens], dtype=np.int32)
|
58 |
+
|
59 |
+
# Set input tensor
|
60 |
+
interpreter.set_tensor(input_details[0]['index'], token_ids)
|
61 |
+
|
62 |
+
# Run inference
|
63 |
+
start_time = time.time()
|
64 |
+
interpreter.invoke()
|
65 |
+
inference_time = time.time() - start_time
|
66 |
+
|
67 |
+
# Get output tensor
|
68 |
+
embedding = interpreter.get_tensor(output_details[0]['index'])
|
69 |
+
|
70 |
+
return embedding, inference_time
|
71 |
+
|
72 |
+
# Function to load sentences from a file
|
73 |
+
def load_sentences(file_path):
|
74 |
+
if not os.path.exists(file_path):
|
75 |
+
return ["Hello world", "This is a test", "Embedding models are useful",
|
76 |
+
"TensorFlow Lite is great for mobile applications",
|
77 |
+
"Streamlit makes it easy to create web apps",
|
78 |
+
"Python is a popular programming language",
|
79 |
+
"Machine learning is an exciting field",
|
80 |
+
"Natural language processing helps computers understand human language",
|
81 |
+
"Semantic search finds meaning, not just keywords",
|
82 |
+
"Quantization reduces model size with minimal accuracy loss"]
|
83 |
+
|
84 |
+
with open(file_path, 'r') as f:
|
85 |
+
sentences = [line.strip() for line in f if line.strip()]
|
86 |
+
|
87 |
+
return sentences
|
88 |
+
|
89 |
+
# Function to find similar sentences
|
90 |
+
def find_similar_sentences(query_embedding, sentence_embeddings, sentences):
|
91 |
+
if query_embedding is None or len(sentence_embeddings) == 0:
|
92 |
+
return []
|
93 |
+
|
94 |
+
# Calculate similarity scores
|
95 |
+
similarities = cosine_similarity(query_embedding, sentence_embeddings)[0]
|
96 |
+
|
97 |
+
# Get indices sorted by similarity (descending)
|
98 |
+
sorted_indices = np.argsort(similarities)[::-1]
|
99 |
+
|
100 |
+
# Create result list
|
101 |
+
results = []
|
102 |
+
for idx in sorted_indices:
|
103 |
+
results.append({
|
104 |
+
"sentence": sentences[idx],
|
105 |
+
"similarity": similarities[idx]
|
106 |
+
})
|
107 |
+
|
108 |
+
return results
|
109 |
+
|
110 |
+
# Main application
|
111 |
+
def main():
|
112 |
+
st.title("Embedding Model Comparison")
|
113 |
+
|
114 |
+
# Sidebar for configuration
|
115 |
+
with st.sidebar:
|
116 |
+
st.header("Configuration")
|
117 |
+
old_model_path = st.text_input("Old Model Path", "old.tflite")
|
118 |
+
new_model_path = st.text_input("New Model Path", "new.tflite")
|
119 |
+
sentences_path = st.text_input("Sentences File Path", "sentences.txt")
|
120 |
+
tokenizer_path = st.text_input("Tokenizer Path", "sentencepiece.model")
|
121 |
+
|
122 |
+
# Load the tokenizer
|
123 |
+
tokenizer = load_tokenizer(tokenizer_path)
|
124 |
+
if tokenizer:
|
125 |
+
st.sidebar.success("Tokenizer loaded successfully")
|
126 |
+
st.sidebar.write(f"Vocabulary size: {tokenizer.get_piece_size()}")
|
127 |
+
else:
|
128 |
+
st.sidebar.error("Failed to load tokenizer")
|
129 |
+
return
|
130 |
+
|
131 |
+
# Load the models
|
132 |
+
st.header("Models")
|
133 |
+
col1, col2 = st.columns(2)
|
134 |
+
|
135 |
+
with col1:
|
136 |
+
st.subheader("Old Model")
|
137 |
+
old_model = load_model(old_model_path)
|
138 |
+
if old_model:
|
139 |
+
st.success("Old model loaded successfully")
|
140 |
+
old_input_details = old_model.get_input_details()
|
141 |
+
old_output_details = old_model.get_output_details()
|
142 |
+
st.write(f"Input shape: {old_input_details[0]['shape']}")
|
143 |
+
st.write(f"Output shape: {old_output_details[0]['shape']}")
|
144 |
+
|
145 |
+
with col2:
|
146 |
+
st.subheader("New Model")
|
147 |
+
new_model = load_model(new_model_path)
|
148 |
+
if new_model:
|
149 |
+
st.success("New model loaded successfully")
|
150 |
+
new_input_details = new_model.get_input_details()
|
151 |
+
new_output_details = new_model.get_output_details()
|
152 |
+
st.write(f"Input shape: {new_input_details[0]['shape']}")
|
153 |
+
st.write(f"Output shape: {new_output_details[0]['shape']}")
|
154 |
+
|
155 |
+
# Load sentences
|
156 |
+
sentences = load_sentences(sentences_path)
|
157 |
+
st.header("Sentences")
|
158 |
+
st.write(f"Loaded {len(sentences)} sentences")
|
159 |
+
if st.checkbox("Show loaded sentences"):
|
160 |
+
st.write(sentences[:10])
|
161 |
+
if len(sentences) > 10:
|
162 |
+
st.write("...")
|
163 |
+
|
164 |
+
# Pre-compute embeddings for all sentences (do this only once for efficiency)
|
165 |
+
if 'old_sentence_embeddings' not in st.session_state or st.button("Recompute Embeddings"):
|
166 |
+
st.session_state.old_sentence_embeddings = []
|
167 |
+
st.session_state.new_sentence_embeddings = []
|
168 |
+
|
169 |
+
if old_model and new_model:
|
170 |
+
progress_bar = st.progress(0)
|
171 |
+
st.write("Computing sentence embeddings...")
|
172 |
+
|
173 |
+
for i, sentence in enumerate(sentences):
|
174 |
+
if i % 10 == 0:
|
175 |
+
progress_bar.progress(i / len(sentences))
|
176 |
+
|
177 |
+
old_embedding, _ = get_embedding(sentence, old_model, tokenizer)
|
178 |
+
new_embedding, _ = get_embedding(sentence, new_model, tokenizer)
|
179 |
+
|
180 |
+
if old_embedding is not None:
|
181 |
+
st.session_state.old_sentence_embeddings.append(old_embedding[0])
|
182 |
+
|
183 |
+
if new_embedding is not None:
|
184 |
+
st.session_state.new_sentence_embeddings.append(new_embedding[0])
|
185 |
+
|
186 |
+
progress_bar.progress(1.0)
|
187 |
+
st.write("Embeddings computed!")
|
188 |
+
|
189 |
+
# Search interface
|
190 |
+
st.header("Search")
|
191 |
+
query = st.text_input("Enter a search query")
|
192 |
+
|
193 |
+
if query and old_model and new_model:
|
194 |
+
# Display tokenization for the query (for debugging)
|
195 |
+
with st.expander("View tokenization"):
|
196 |
+
tokens = tokenizer.encode(query, out_type=int)
|
197 |
+
pieces = tokenizer.encode(query, out_type=str)
|
198 |
+
st.write("Token IDs:", tokens)
|
199 |
+
st.write("Token pieces:", pieces)
|
200 |
+
|
201 |
+
# Get query embeddings
|
202 |
+
old_query_embedding, old_time = get_embedding(query, old_model, tokenizer)
|
203 |
+
new_query_embedding, new_time = get_embedding(query, new_model, tokenizer)
|
204 |
+
|
205 |
+
# Find similar sentences
|
206 |
+
old_results = find_similar_sentences(
|
207 |
+
old_query_embedding,
|
208 |
+
st.session_state.old_sentence_embeddings,
|
209 |
+
sentences
|
210 |
+
)
|
211 |
+
|
212 |
+
new_results = find_similar_sentences(
|
213 |
+
new_query_embedding,
|
214 |
+
st.session_state.new_sentence_embeddings,
|
215 |
+
sentences
|
216 |
+
)
|
217 |
+
|
218 |
+
# Add rank information
|
219 |
+
for i, result in enumerate(old_results):
|
220 |
+
result["rank"] = i + 1
|
221 |
+
|
222 |
+
for i, result in enumerate(new_results):
|
223 |
+
result["rank"] = i + 1
|
224 |
+
|
225 |
+
# Create separate dataframes
|
226 |
+
old_df = pd.DataFrame([
|
227 |
+
{"Sentence": r["sentence"], "Similarity": f"{r['similarity']:.4f}", "Rank": r["rank"]}
|
228 |
+
for r in old_results
|
229 |
+
])
|
230 |
+
|
231 |
+
new_df = pd.DataFrame([
|
232 |
+
{"Sentence": r["sentence"], "Similarity": f"{r['similarity']:.4f}", "Rank": r["rank"]}
|
233 |
+
for r in new_results
|
234 |
+
])
|
235 |
+
|
236 |
+
# Display results in two columns
|
237 |
+
st.subheader("Search Results")
|
238 |
+
col1, col2 = st.columns(2)
|
239 |
+
|
240 |
+
with col1:
|
241 |
+
st.markdown("### Old Model Results")
|
242 |
+
st.dataframe(old_df, use_container_width=True)
|
243 |
+
|
244 |
+
with col2:
|
245 |
+
st.markdown("### New Model Results")
|
246 |
+
st.dataframe(new_df, use_container_width=True)
|
247 |
+
|
248 |
+
# Show timing information
|
249 |
+
st.subheader("Inference Time")
|
250 |
+
st.write(f"Old model: {old_time * 1000:.2f} ms")
|
251 |
+
st.write(f"New model: {new_time * 1000:.2f} ms")
|
252 |
+
st.write(f"Speed improvement: {old_time / new_time:.2f}x")
|
253 |
+
|
254 |
+
# Show embedding visualizations
|
255 |
+
st.subheader("Embedding Visualizations")
|
256 |
+
col1, col2 = st.columns(2)
|
257 |
+
|
258 |
+
with col1:
|
259 |
+
st.write("Old Model Embedding (first 20 dimensions)")
|
260 |
+
st.bar_chart(pd.DataFrame({
|
261 |
+
'value': old_query_embedding[0][:20]
|
262 |
+
}))
|
263 |
+
|
264 |
+
with col2:
|
265 |
+
st.write("New Model Embedding (first 20 dimensions)")
|
266 |
+
st.bar_chart(pd.DataFrame({
|
267 |
+
'value': new_query_embedding[0][:20]
|
268 |
+
}))
|
269 |
+
|
270 |
+
if __name__ == "__main__":
|
271 |
+
main()
|
new.tflite
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b2515d1bd7ea23e80d38772a0412fbe7fc4c07e6a2cb4d70ff4e4ab0e091071
|
3 |
+
size 36841784
|
old.tflite
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f916e761c1bc9da3a714c55793b8f1fa347365bff2191dfd27f2589380456cc8
|
3 |
+
size 85885944
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
tensorflow
|
3 |
+
numpy
|
4 |
+
scikit-learn
|
5 |
+
pandas
|
6 |
+
sentencepiece
|
sentencepiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:128a5a86dce5ebcbb5dfe3282cfff769b5eea9708c3cd0fed49add3f7f7f1802
|
3 |
+
size 794334
|
sentences.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
I need to renew my passport and visa before my international trip next month.
|
2 |
+
The immigration officer carefully examined all my travel documents at the border checkpoint.
|
3 |
+
My suitcase contains clothes for both warm and cold weather.
|
4 |
+
Remember to make photocopies of your passport, visa, and travel insurance policy.
|
5 |
+
The hotel requires a valid ID at check-in for all guests.
|
6 |
+
My favorite travel destination is a small coastal town in Italy.
|
7 |
+
The airline sent me an email with my boarding pass and flight itinerary.
|
8 |
+
I enjoy trying local cuisine whenever I visit a new country.
|
9 |
+
Public transportation is an affordable way to get around in most European cities.
|
10 |
+
Don't forget to bring a pen to fill out customs declaration forms on the plane.
|