Javier Vera
committed on
Update rag_hf.py
Browse files
rag_hf.py
CHANGED
@@ -33,7 +33,6 @@ st.set_page_config(
|
|
33 |
# === CUSTOM CSS ===
|
34 |
st.markdown("""
|
35 |
<style>
|
36 |
-
/* Main styles */
|
37 |
.header {
|
38 |
color: #2c3e50;
|
39 |
border-bottom: 2px solid #4f46e5;
|
@@ -61,8 +60,6 @@ st.markdown("""
|
|
61 |
margin: 0.5rem 0;
|
62 |
border: 1px solid #e5e7eb;
|
63 |
}
|
64 |
-
|
65 |
-
/* Sidebar styles */
|
66 |
.sidebar-section {
|
67 |
margin-bottom: 1.5rem;
|
68 |
}
|
@@ -80,8 +77,6 @@ st.markdown("""
|
|
80 |
.suggested-question:hover {
|
81 |
background-color: #f1f5f9;
|
82 |
}
|
83 |
-
|
84 |
-
/* Metrics and badges */
|
85 |
.metric-badge {
|
86 |
display: inline-block;
|
87 |
background-color: #e8f4fc;
|
@@ -120,7 +115,7 @@ def load_all_components():
|
|
120 |
methods[label] = (matrix, id_map, G, rdf)
|
121 |
return methods, embedder
|
122 |
|
123 |
-
def get_top_k(matrix, id_map, query, k):
|
124 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
125 |
vec = vec.cpu().numpy().astype("float32")
|
126 |
sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
|
@@ -148,8 +143,8 @@ def query_rdf(rdf, lang_id):
|
|
148 |
except Exception as e:
|
149 |
return [("error", str(e))]
|
150 |
|
151 |
-
def generate_response(matrix, id_map, G, rdf, user_question, k
|
152 |
-
ids = get_top_k(matrix, id_map, user_question, k)
|
153 |
context = [get_context(G, i) for i in ids]
|
154 |
rdf_facts = []
|
155 |
for i in ids:
|
@@ -183,17 +178,14 @@ Answer:
|
|
183 |
|
184 |
# === MAIN APP ===
|
185 |
def main():
|
186 |
-
# Load components
|
187 |
methods, embedder = load_all_components()
|
188 |
|
189 |
-
# Main header
|
190 |
st.markdown("""
|
191 |
<div class="header">
|
192 |
<h1>π Language Atlas: South American Indigenous Languages</h1>
|
193 |
</div>
|
194 |
""", unsafe_allow_html=True)
|
195 |
|
196 |
-
# Overview section
|
197 |
with st.expander("π Overview", expanded=True):
|
198 |
st.markdown("""
|
199 |
This app provides **AI-powered analysis** of endangered indigenous languages in South America,
|
@@ -226,13 +218,9 @@ def main():
|
|
226 |
</div>
|
227 |
""", unsafe_allow_html=True)
|
228 |
|
229 |
-
# Sidebar
|
230 |
with st.sidebar:
|
231 |
-
# Logo and academic info
|
232 |
st.markdown("### Departamento Académico de Humanidades")
|
233 |
st.markdown("---")
|
234 |
-
|
235 |
-
# Quick start guide
|
236 |
st.markdown("### π Quick Start")
|
237 |
st.markdown("""
|
238 |
1. **Type a question** in the input box
|
@@ -241,8 +229,6 @@ def main():
|
|
241 |
""")
|
242 |
|
243 |
st.markdown("---")
|
244 |
-
|
245 |
-
# Suggested questions
|
246 |
st.markdown("### π Example Queries")
|
247 |
questions = [
|
248 |
"What languages are endangered in Brazil?",
|
@@ -256,8 +242,6 @@ def main():
|
|
256 |
st.session_state.query = q
|
257 |
|
258 |
st.markdown("---")
|
259 |
-
|
260 |
-
# Technical details
|
261 |
st.markdown("### ⚙️ Technical Details")
|
262 |
st.markdown("""
|
263 |
- <span class="tech-badge">Embeddings</span> Node2Vec vs. GraphSAGE
|
@@ -266,8 +250,6 @@ def main():
|
|
266 |
""", unsafe_allow_html=True)
|
267 |
|
268 |
st.markdown("---")
|
269 |
-
|
270 |
-
# Data sources
|
271 |
st.markdown("### π Data Sources")
|
272 |
st.markdown("""
|
273 |
- **Glottolog** (Language classification)
|
@@ -276,18 +258,13 @@ def main():
|
|
276 |
""")
|
277 |
|
278 |
st.markdown("---")
|
279 |
-
|
280 |
-
# Analysis parameters
|
281 |
st.markdown("### π Analysis Parameters")
|
282 |
k = st.slider("Number of languages to analyze", 1, 10, 3)
|
283 |
st.markdown("---")
|
284 |
-
|
285 |
-
# Debug options
|
286 |
st.markdown("### 🔧 Advanced Options")
|
287 |
show_ctx = st.checkbox("Show context information", False)
|
288 |
show_rdf = st.checkbox("Show structured facts", False)
|
289 |
|
290 |
-
# Main query interface
|
291 |
st.markdown("### π Ask About Indigenous Languages")
|
292 |
query = st.text_input(
|
293 |
"Enter your question:",
|
@@ -312,10 +289,9 @@ def main():
|
|
312 |
}[label])
|
313 |
|
314 |
start = datetime.datetime.now()
|
315 |
-
response, lang_ids, context, rdf_data = generate_response(*method, query, k)
|
316 |
duration = (datetime.datetime.now() - start).total_seconds()
|
317 |
|
318 |
-
# Response display
|
319 |
st.markdown(f"""
|
320 |
<div class="response-card">
|
321 |
{response}
|
@@ -326,7 +302,6 @@ def main():
|
|
326 |
</div>
|
327 |
""", unsafe_allow_html=True)
|
328 |
|
329 |
-
# Additional information
|
330 |
if show_ctx:
|
331 |
with st.expander(f"π Context from {len(lang_ids)} languages"):
|
332 |
for lang_id, ctx in zip(lang_ids, context):
|
@@ -336,7 +311,6 @@ def main():
|
|
336 |
with st.expander("π Structured facts (RDF)"):
|
337 |
st.code("\n".join(rdf_data))
|
338 |
|
339 |
-
# Footer note
|
340 |
st.markdown("---")
|
341 |
st.markdown("""
|
342 |
<div style="font-size: 0.8rem; color: #64748b; text-align: center;">
|
|
|
33 |
# === CUSTOM CSS ===
|
34 |
st.markdown("""
|
35 |
<style>
|
|
|
36 |
.header {
|
37 |
color: #2c3e50;
|
38 |
border-bottom: 2px solid #4f46e5;
|
|
|
60 |
margin: 0.5rem 0;
|
61 |
border: 1px solid #e5e7eb;
|
62 |
}
|
|
|
|
|
63 |
.sidebar-section {
|
64 |
margin-bottom: 1.5rem;
|
65 |
}
|
|
|
77 |
.suggested-question:hover {
|
78 |
background-color: #f1f5f9;
|
79 |
}
|
|
|
|
|
80 |
.metric-badge {
|
81 |
display: inline-block;
|
82 |
background-color: #e8f4fc;
|
|
|
115 |
methods[label] = (matrix, id_map, G, rdf)
|
116 |
return methods, embedder
|
117 |
|
118 |
+
def get_top_k(matrix, id_map, query, k, embedder):
|
119 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
120 |
vec = vec.cpu().numpy().astype("float32")
|
121 |
sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
|
|
|
143 |
except Exception as e:
|
144 |
return [("error", str(e))]
|
145 |
|
146 |
+
def generate_response(matrix, id_map, G, rdf, user_question, k, embedder):
|
147 |
+
ids = get_top_k(matrix, id_map, user_question, k, embedder)
|
148 |
context = [get_context(G, i) for i in ids]
|
149 |
rdf_facts = []
|
150 |
for i in ids:
|
|
|
178 |
|
179 |
# === MAIN APP ===
|
180 |
def main():
|
|
|
181 |
methods, embedder = load_all_components()
|
182 |
|
|
|
183 |
st.markdown("""
|
184 |
<div class="header">
|
185 |
<h1>π Language Atlas: South American Indigenous Languages</h1>
|
186 |
</div>
|
187 |
""", unsafe_allow_html=True)
|
188 |
|
|
|
189 |
with st.expander("π Overview", expanded=True):
|
190 |
st.markdown("""
|
191 |
This app provides **AI-powered analysis** of endangered indigenous languages in South America,
|
|
|
218 |
</div>
|
219 |
""", unsafe_allow_html=True)
|
220 |
|
|
|
221 |
with st.sidebar:
|
|
|
222 |
st.markdown("### Departamento Académico de Humanidades")
|
223 |
st.markdown("---")
|
|
|
|
|
224 |
st.markdown("### π Quick Start")
|
225 |
st.markdown("""
|
226 |
1. **Type a question** in the input box
|
|
|
229 |
""")
|
230 |
|
231 |
st.markdown("---")
|
|
|
|
|
232 |
st.markdown("### π Example Queries")
|
233 |
questions = [
|
234 |
"What languages are endangered in Brazil?",
|
|
|
242 |
st.session_state.query = q
|
243 |
|
244 |
st.markdown("---")
|
|
|
|
|
245 |
st.markdown("### ⚙️ Technical Details")
|
246 |
st.markdown("""
|
247 |
- <span class="tech-badge">Embeddings</span> Node2Vec vs. GraphSAGE
|
|
|
250 |
""", unsafe_allow_html=True)
|
251 |
|
252 |
st.markdown("---")
|
|
|
|
|
253 |
st.markdown("### π Data Sources")
|
254 |
st.markdown("""
|
255 |
- **Glottolog** (Language classification)
|
|
|
258 |
""")
|
259 |
|
260 |
st.markdown("---")
|
|
|
|
|
261 |
st.markdown("### π Analysis Parameters")
|
262 |
k = st.slider("Number of languages to analyze", 1, 10, 3)
|
263 |
st.markdown("---")
|
|
|
|
|
264 |
st.markdown("### 🔧 Advanced Options")
|
265 |
show_ctx = st.checkbox("Show context information", False)
|
266 |
show_rdf = st.checkbox("Show structured facts", False)
|
267 |
|
|
|
268 |
st.markdown("### π Ask About Indigenous Languages")
|
269 |
query = st.text_input(
|
270 |
"Enter your question:",
|
|
|
289 |
}[label])
|
290 |
|
291 |
start = datetime.datetime.now()
|
292 |
+
response, lang_ids, context, rdf_data = generate_response(*method, query, k, embedder)
|
293 |
duration = (datetime.datetime.now() - start).total_seconds()
|
294 |
|
|
|
295 |
st.markdown(f"""
|
296 |
<div class="response-card">
|
297 |
{response}
|
|
|
302 |
</div>
|
303 |
""", unsafe_allow_html=True)
|
304 |
|
|
|
305 |
if show_ctx:
|
306 |
with st.expander(f"π Context from {len(lang_ids)} languages"):
|
307 |
for lang_id, ctx in zip(lang_ids, context):
|
|
|
311 |
with st.expander("π Structured facts (RDF)"):
|
312 |
st.code("\n".join(rdf_data))
|
313 |
|
|
|
314 |
st.markdown("---")
|
315 |
st.markdown("""
|
316 |
<div style="font-size: 0.8rem; color: #64748b; text-align: center;">
|