Javier Vera
commited on
Update rag_hf.py
Browse files
rag_hf.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# rag_interface.py (
|
2 |
import streamlit as st
|
3 |
import pickle
|
4 |
import numpy as np
|
@@ -14,13 +14,59 @@ from dotenv import load_dotenv
|
|
14 |
# === CONFIGURATION ===
|
15 |
load_dotenv()
|
16 |
|
17 |
-
|
18 |
-
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
19 |
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
|
20 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
EX = Namespace("http://example.org/lang/")
|
22 |
|
23 |
-
st.set_page_config(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
@st.cache_resource(show_spinner="Loading models and indexes...")
|
26 |
def load_all_components():
|
@@ -43,6 +89,7 @@ def load_all_components():
|
|
43 |
|
44 |
methods, embedder = load_all_components()
|
45 |
|
|
|
46 |
def get_top_k(matrix, id_map, query, k):
|
47 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
48 |
vec = vec.cpu().numpy().astype("float32")
|
@@ -52,14 +99,14 @@ def get_top_k(matrix, id_map, query, k):
|
|
52 |
|
53 |
def get_context(G, lang_id):
|
54 |
node = G.nodes.get(lang_id, {})
|
55 |
-
|
56 |
if node.get("wikipedia_summary"):
|
57 |
-
|
58 |
if node.get("wikidata_description"):
|
59 |
-
|
60 |
if node.get("wikidata_countries"):
|
61 |
-
|
62 |
-
return "\n\n".join(
|
63 |
|
64 |
def query_rdf(rdf, lang_id):
|
65 |
q = f"""
|
@@ -67,7 +114,10 @@ def query_rdf(rdf, lang_id):
|
|
67 |
SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
|
68 |
"""
|
69 |
try:
|
70 |
-
return [
|
|
|
|
|
|
|
71 |
except Exception as e:
|
72 |
return [("error", str(e))]
|
73 |
|
@@ -84,6 +134,7 @@ Use strictly and only the information below to answer the user question in **Eng
|
|
84 |
- If the answer is unknown or insufficient, say "I cannot answer with the available data."
|
85 |
- Limit your answer to 100 words.
|
86 |
|
|
|
87 |
### CONTEXT:
|
88 |
{chr(10).join(context)}
|
89 |
|
@@ -97,9 +148,9 @@ Answer:
|
|
97 |
[/INST]"""
|
98 |
try:
|
99 |
res = requests.post(
|
100 |
-
|
101 |
-
headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
|
102 |
-
json={"inputs": prompt}, timeout=
|
103 |
)
|
104 |
out = res.json()
|
105 |
if isinstance(out, list) and "generated_text" in out[0]:
|
@@ -108,6 +159,7 @@ Answer:
|
|
108 |
except Exception as e:
|
109 |
return str(e), ids, context, rdf_facts
|
110 |
|
|
|
111 |
def main():
|
112 |
st.markdown("""
|
113 |
<h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
|
@@ -120,39 +172,82 @@ def main():
|
|
120 |
with st.sidebar:
|
121 |
st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
|
122 |
|
123 |
-
st.
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
with col:
|
141 |
-
st.subheader(label)
|
142 |
start = datetime.datetime.now()
|
143 |
response, lang_ids, context, rdf_data = generate_response(*method, query, k)
|
144 |
duration = (datetime.datetime.now() - start).total_seconds()
|
145 |
st.markdown(response)
|
146 |
st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
|
147 |
if show_ids:
|
148 |
-
st.markdown("**
|
149 |
st.code("\n".join(lang_ids))
|
150 |
if show_ctx:
|
151 |
-
st.markdown("**Context:**")
|
152 |
st.markdown("\n\n---\n\n".join(context))
|
153 |
if show_rdf:
|
154 |
-
st.markdown("**RDF:**")
|
155 |
st.code("\n".join(rdf_data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
if __name__ == "__main__":
|
158 |
main()
|
|
|
1 |
+
# rag_interface.py (with numpy instead of faiss)
|
2 |
import streamlit as st
|
3 |
import pickle
|
4 |
import numpy as np
|
|
|
14 |
# === CONFIGURATION ===
|
15 |
load_dotenv()
|
16 |
|
17 |
+
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
|
|
|
18 |
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
|
19 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
EX = Namespace("http://example.org/lang/")
|
21 |
|
22 |
+
st.set_page_config(
|
23 |
+
page_title="Vanishing Voices: Language Atlas",
|
24 |
+
page_icon="🌍",
|
25 |
+
layout="wide",
|
26 |
+
initial_sidebar_state="expanded"
|
27 |
+
)
|
28 |
+
|
29 |
+
# Custom CSS
|
30 |
+
st.markdown("""
|
31 |
+
<style>
|
32 |
+
.header {
|
33 |
+
color: #2c3e50;
|
34 |
+
border-bottom: 2px solid #3498db;
|
35 |
+
padding-bottom: 10px;
|
36 |
+
margin-bottom: 1.5rem;
|
37 |
+
}
|
38 |
+
.info-box {
|
39 |
+
background-color: #e8f4fc;
|
40 |
+
border-radius: 8px;
|
41 |
+
padding: 1rem;
|
42 |
+
margin-bottom: 1.5rem;
|
43 |
+
border-left: 4px solid #3498db;
|
44 |
+
}
|
45 |
+
.sidebar-section {
|
46 |
+
margin-bottom: 2rem;
|
47 |
+
}
|
48 |
+
.sidebar-title {
|
49 |
+
color: #2c3e50;
|
50 |
+
font-size: 1.1rem;
|
51 |
+
font-weight: 600;
|
52 |
+
margin-bottom: 0.5rem;
|
53 |
+
border-bottom: 1px solid #eee;
|
54 |
+
padding-bottom: 0.5rem;
|
55 |
+
}
|
56 |
+
.method-card {
|
57 |
+
background-color: #f8f9fa;
|
58 |
+
border-radius: 8px;
|
59 |
+
padding: 0.8rem;
|
60 |
+
margin-bottom: 0.8rem;
|
61 |
+
border-left: 3px solid #3498db;
|
62 |
+
}
|
63 |
+
.method-title {
|
64 |
+
font-weight: 600;
|
65 |
+
color: #3498db;
|
66 |
+
margin-bottom: 0.3rem;
|
67 |
+
}
|
68 |
+
</style>
|
69 |
+
""", unsafe_allow_html=True)
|
70 |
|
71 |
@st.cache_resource(show_spinner="Loading models and indexes...")
|
72 |
def load_all_components():
|
|
|
89 |
|
90 |
methods, embedder = load_all_components()
|
91 |
|
92 |
+
# === CORE FUNCTIONS ===
|
93 |
def get_top_k(matrix, id_map, query, k):
|
94 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
95 |
vec = vec.cpu().numpy().astype("float32")
|
|
|
99 |
|
100 |
def get_context(G, lang_id):
|
101 |
node = G.nodes.get(lang_id, {})
|
102 |
+
lines = [f"**Language:** {node.get('label', lang_id)}"]
|
103 |
if node.get("wikipedia_summary"):
|
104 |
+
lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
|
105 |
if node.get("wikidata_description"):
|
106 |
+
lines.append(f"**Wikidata:** {node['wikidata_description']}")
|
107 |
if node.get("wikidata_countries"):
|
108 |
+
lines.append(f"**Countries:** {node['wikidata_countries']}")
|
109 |
+
return "\n\n".join(lines)
|
110 |
|
111 |
def query_rdf(rdf, lang_id):
|
112 |
q = f"""
|
|
|
114 |
SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
|
115 |
"""
|
116 |
try:
|
117 |
+
return [
|
118 |
+
(str(row[0]).split("/")[-1], str(row[1]))
|
119 |
+
for row in rdf.query(q)
|
120 |
+
]
|
121 |
except Exception as e:
|
122 |
return [("error", str(e))]
|
123 |
|
|
|
134 |
- If the answer is unknown or insufficient, say "I cannot answer with the available data."
|
135 |
- Limit your answer to 100 words.
|
136 |
|
137 |
+
|
138 |
### CONTEXT:
|
139 |
{chr(10).join(context)}
|
140 |
|
|
|
148 |
[/INST]"""
|
149 |
try:
|
150 |
res = requests.post(
|
151 |
+
f"https://api-inference.huggingface.co/models/{MODEL_ID}",
|
152 |
+
headers={"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}", "Content-Type": "application/json"},
|
153 |
+
json={"inputs": prompt}, timeout=30
|
154 |
)
|
155 |
out = res.json()
|
156 |
if isinstance(out, list) and "generated_text" in out[0]:
|
|
|
159 |
except Exception as e:
|
160 |
return str(e), ids, context, rdf_facts
|
161 |
|
162 |
+
# === MAIN FUNCTION ===
|
163 |
def main():
|
164 |
st.markdown("""
|
165 |
<h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
|
|
|
172 |
with st.sidebar:
|
173 |
st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
|
174 |
|
175 |
+
with st.container():
|
176 |
+
st.markdown('<div class="sidebar-title">About This Tool</div>', unsafe_allow_html=True)
|
177 |
+
st.markdown("""
|
178 |
+
<div class="method-card">
|
179 |
+
<div class="method-title">Standard Search</div>
|
180 |
+
Semantic retrieval based on text-only embeddings. Identifies languages using purely linguistic similarity from Wikipedia summaries and labels.
|
181 |
+
</div>
|
182 |
+
<div class="method-card">
|
183 |
+
<div class="method-title">Hybrid Search</div>
|
184 |
+
Combines semantic embeddings with structured data from knowledge graphs. Enriches language representation with contextual facts.
|
185 |
+
</div>
|
186 |
+
<div class="method-card">
|
187 |
+
<div class="method-title">GraphSAGE Search</div>
|
188 |
+
Leverages deep graph neural networks to learn relational patterns across languages. Captures complex cultural and genealogical connections.
|
189 |
+
</div>
|
190 |
+
""", unsafe_allow_html=True)
|
191 |
+
|
192 |
+
with st.container():
|
193 |
+
st.markdown('<div class="sidebar-title">Research Settings</div>', unsafe_allow_html=True)
|
194 |
+
k = st.slider("Languages to analyze per query", 1, 10, 3)
|
195 |
+
st.markdown("**Display Options:**")
|
196 |
+
show_ids = st.checkbox("Language IDs", value=True, key="show_ids")
|
197 |
+
show_ctx = st.checkbox("Cultural Context", value=True, key="show_ctx")
|
198 |
+
show_rdf = st.checkbox("RDF Relations", value=True, key="show_rdf")
|
199 |
+
|
200 |
+
with st.container():
|
201 |
+
st.markdown('<div class="sidebar-title">Data Sources</div>', unsafe_allow_html=True)
|
202 |
+
st.markdown("""
|
203 |
+
- Glottolog
|
204 |
+
- Wikidata
|
205 |
+
- Wikipedia
|
206 |
+
- Ethnologue
|
207 |
+
""")
|
208 |
+
|
209 |
+
query = st.text_input("Ask about indigenous languages:", "Which Amazonian languages are most at risk?")
|
210 |
+
|
211 |
+
if st.button("Analyze with All Methods") and query:
|
212 |
+
col1, col2, col3 = st.columns(3)
|
213 |
+
results = {}
|
214 |
+
for col, (label, method) in zip([col1, col2, col3], methods.items()):
|
215 |
with col:
|
216 |
+
st.subheader(f"{label} Analysis")
|
217 |
start = datetime.datetime.now()
|
218 |
response, lang_ids, context, rdf_data = generate_response(*method, query, k)
|
219 |
duration = (datetime.datetime.now() - start).total_seconds()
|
220 |
st.markdown(response)
|
221 |
st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
|
222 |
if show_ids:
|
223 |
+
st.markdown("**Language Identifiers:**")
|
224 |
st.code("\n".join(lang_ids))
|
225 |
if show_ctx:
|
226 |
+
st.markdown("**Cultural Context:**")
|
227 |
st.markdown("\n\n---\n\n".join(context))
|
228 |
if show_rdf:
|
229 |
+
st.markdown("**RDF Knowledge:**")
|
230 |
st.code("\n".join(rdf_data))
|
231 |
+
results[label] = response
|
232 |
+
|
233 |
+
log = f"""
|
234 |
+
[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]
|
235 |
+
QUERY: {query}
|
236 |
+
STANDARD:
|
237 |
+
{results.get('Standard', '')}
|
238 |
+
|
239 |
+
HYBRID:
|
240 |
+
{results.get('Hybrid', '')}
|
241 |
+
|
242 |
+
GRAPH-SAGE:
|
243 |
+
{results.get('GraphSAGE', '')}
|
244 |
+
{'='*60}
|
245 |
+
"""
|
246 |
+
try:
|
247 |
+
with open("language_analysis_logs.txt", "a", encoding="utf-8") as f:
|
248 |
+
f.write(log)
|
249 |
+
except Exception as e:
|
250 |
+
st.warning(f"Failed to log: {str(e)}")
|
251 |
|
252 |
if __name__ == "__main__":
|
253 |
main()
|