Javier Vera
committed on
Update rag_hf.py
rag_hf.py
CHANGED
@@ -1,253 +1,190 @@
- # rag_interface.py (
- import streamlit as st
- import pickle
- import numpy as np
- import rdflib
- import torch
- import datetime
- import os
- import requests
- from rdflib import Graph as RDFGraph, Namespace
- from sentence_transformers import SentenceTransformer
- from dotenv import load_dotenv
-
- # === CONFIGURATION ===
- load_dotenv()
- ]
- def
-     with st.container():
-         st.markdown('<div class="sidebar-title">Research Settings</div>', unsafe_allow_html=True)
-         k = st.slider("Languages to analyze per query", 1, 10, 3)
-         st.markdown("**Display Options:**")
-         show_ids = st.checkbox("Language IDs", value=True, key="show_ids")
-         show_ctx = st.checkbox("Cultural Context", value=True, key="show_ctx")
-         show_rdf = st.checkbox("RDF Relations", value=True, key="show_rdf")
-
-     with st.container():
-         st.markdown('<div class="sidebar-title">Data Sources</div>', unsafe_allow_html=True)
-         st.markdown("""
-         - Glottolog
-         - Wikidata
-         - Wikipedia
-         - Ethnologue
-         """)
-
-     query = st.text_input("Ask about indigenous languages:", "Which Amazonian languages are most at risk?")
-
-     if st.button("Analyze with All Methods") and query:
-         col1, col2, col3 = st.columns(3)
-         results = {}
-         for col, (label, method) in zip([col1, col2, col3], methods.items()):
-             with col:
-                 st.subheader(f"{label} Analysis")
-                 start = datetime.datetime.now()
-                 response, lang_ids, context, rdf_data = generate_response(*method, query, k)
-                 duration = (datetime.datetime.now() - start).total_seconds()
-                 st.markdown(response)
-                 st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
-                 if show_ids:
-                     st.markdown("**Language Identifiers:**")
-                     st.code("\n".join(lang_ids))
-                 if show_ctx:
-                     st.markdown("**Cultural Context:**")
-                     st.markdown("\n\n---\n\n".join(context))
-                 if show_rdf:
-                     st.markdown("**RDF Knowledge:**")
-                     st.code("\n".join(rdf_data))
-                 results[label] = response
-
-         log = f"""
- [{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]
- QUERY: {query}
- STANDARD:
- {results.get('Standard', '')}
-
- HYBRID:
- {results.get('Hybrid', '')}
-
- GRAPH-SAGE:
- {results.get('GraphSAGE', '')}
- {'='*60}
- """
-         try:
-             with open("language_analysis_logs.txt", "a", encoding="utf-8") as f:
-                 f.write(log)
-         except Exception as e:
-             st.warning(f"Failed to log: {str(e)}")
-
- if __name__ == "__main__":
-     main()
+ # rag_interface.py (Hybrid & GraphSAGE only, simplified explanations, renamed methods)
+ import streamlit as st
+ import pickle
+ import numpy as np
+ import rdflib
+ import torch
+ import datetime
+ import os
+ import requests
+ from rdflib import Graph as RDFGraph, Namespace
+ from sentence_transformers import SentenceTransformer
+ from dotenv import load_dotenv
+
+ # === CONFIGURATION ===
+ load_dotenv()
+
+ ENDPOINT_URL = os.getenv("HF_ENDPOINT")
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+ EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ EX = Namespace("http://example.org/lang/")
+
+ st.set_page_config(
+     page_title="Vanishing Voices: Language Atlas",
+     page_icon="🌍",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # Custom CSS
+ st.markdown("""
+ <style>
+ .header {
+     color: #2c3e50;
+     border-bottom: 2px solid #3498db;
+     padding-bottom: 10px;
+     margin-bottom: 1.5rem;
+ }
+ .info-box {
+     background-color: #e8f4fc;
+     border-radius: 8px;
+     padding: 1rem;
+     margin-bottom: 1.5rem;
+     border-left: 4px solid #3498db;
+ }
+ .sidebar-title {
+     font-size: 1.1rem;
+     font-weight: 600;
+     margin-top: 1rem;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ @st.cache_resource(show_spinner="Loading models and indexes...")
+ def load_all_components():
+     embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
+     methods = {}
+     for label, suffix, ttl, matrix_path in [
+         ("InfoMatch", "_hybrid", "grafo_ttl_hibrido.ttl", "embed_matrix_hybrid.npy"),
+         ("LinkGraph", "_hybrid_graphsage", "grafo_ttl_hibrido_graphsage.ttl", "embed_matrix_hybrid_graphsage.npy")
+     ]:
+         with open(f"id_map{suffix}.pkl", "rb") as f:
+             id_map = pickle.load(f)
+         with open(f"grafo_embed{suffix}.pickle", "rb") as f:
+             G = pickle.load(f)
+         matrix = np.load(matrix_path)
+         rdf = RDFGraph()
+         rdf.parse(ttl, format="ttl")
+         methods[label] = (matrix, id_map, G, rdf)
+     return methods, embedder
+
+ methods, embedder = load_all_components()
+
+ # === CORE FUNCTIONS ===
+ def get_top_k(matrix, id_map, query, k):
+     vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
+     vec = vec.cpu().numpy().astype("float32")
+     sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
+     top_k_idx = np.argsort(sims)[-k:][::-1]
+     return [id_map[i] for i in top_k_idx]
+
+ def get_context(G, lang_id):
+     node = G.nodes.get(lang_id, {})
+     lines = [f"**Language:** {node.get('label', lang_id)}"]
+     if node.get("wikipedia_summary"):
+         lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
+     if node.get("wikidata_description"):
+         lines.append(f"**Wikidata:** {node['wikidata_description']}")
+     if node.get("wikidata_countries"):
+         lines.append(f"**Countries:** {node['wikidata_countries']}")
+     return "\n\n".join(lines)
+
+ def query_rdf(rdf, lang_id):
+     q = f"""
+     PREFIX ex: <http://example.org/lang/>
+     SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
+     """
+     try:
+         return [(str(row[0]).split("/")[-1], str(row[1])) for row in rdf.query(q)]
+     except Exception as e:
+         return [("error", str(e))]
+
+ def generate_response(matrix, id_map, G, rdf, user_question, k=3):
+     ids = get_top_k(matrix, id_map, user_question, k)
+     context = [get_context(G, i) for i in ids]
+     rdf_facts = []
+     for i in ids:
+         rdf_facts.extend([f"{p}: {v}" for p, v in query_rdf(rdf, i)])
+     prompt = f"""<s>[INST]
+ You are an expert in South American indigenous languages.
+ Use strictly and only the information below to answer the user question in **English**.
+ - Do not infer or assume facts that are not explicitly stated.
+ - If the answer is unknown or insufficient, say \"I cannot answer with the available data.\"
+ - Limit your answer to 100 words.
+
+
+ ### CONTEXT:
+ {chr(10).join(context)}
+
+ ### RDF RELATIONS:
+ {chr(10).join(rdf_facts)}
+
+ ### QUESTION:
+ {user_question}
+
+ Answer:
+ [/INST]"""
+     try:
+         res = requests.post(
+             ENDPOINT_URL,
+             headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
+             json={"inputs": prompt}, timeout=60
+         )
+         out = res.json()
+         if isinstance(out, list) and "generated_text" in out[0]:
+             return out[0]["generated_text"].replace(prompt.strip(), "").strip(), ids, context, rdf_facts
+         return str(out), ids, context, rdf_facts
+     except Exception as e:
+         return str(e), ids, context, rdf_facts
+
+ # === MAIN FUNCTION ===
+ def main():
+     st.markdown("""
+ <h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
+ <div class='info-box'>
+ <b>Why this matters:</b> Many indigenous languages in South America are disappearing. This app helps understand and preserve them using artificial intelligence.
+ </div>
+ """, unsafe_allow_html=True)
+
+     with st.sidebar:
+         st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
+
+         st.markdown("### What are the methods?")
+         st.markdown("""
+ - **Graph A**: Combines descriptions, country info, and speaker data using classic node2vec embeddings.
+ - **Graph B**: Uses graph learning (GraphSAGE) to detect patterns in how languages relate to each other.
+ """)
+
+         st.markdown("### Options")
+         k = st.slider("How many languages to analyze?", 1, 10, 3)
+         show_ids = st.checkbox("Show IDs", value=True)
+         show_ctx = st.checkbox("Show Text Info", value=True)
+         show_rdf = st.checkbox("Show Extra Facts", value=True)
+
+     query = st.text_input("Ask something about South American languages:", "What languages are spoken in Perú?")
+
+     if st.button("Analyze") and query:
+         col1, col2 = st.columns(2)
+         results = {}
+         for col, (label, method) in zip([col1, col2], methods.items()):
+             with col:
+                 st.subheader(f"{label} Method")
+                 start = datetime.datetime.now()
+                 response, lang_ids, context, rdf_data = generate_response(*method, query, k)
+                 duration = (datetime.datetime.now() - start).total_seconds()
+                 st.markdown(response)
+                 st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
+                 if show_ids:
+                     st.markdown("**Language IDs:**")
+                     st.code("\n".join(lang_ids))
+                 if show_ctx:
+                     st.markdown("**Text Info:**")
+                     st.markdown("\n\n---\n\n".join(context))
+                 if show_rdf:
+                     st.markdown("**Extra Facts:**")
+                     st.code("\n".join(rdf_data))
+
+ if __name__ == "__main__":
+     main()
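For reference, the retrieval-plus-generation path added here can also be exercised outside the Streamlit UI. The sketch below is illustrative only: it assumes the definitions from the new rag_hf.py are already in scope, the pickled graphs, .npy matrices and .ttl files sit next to the script, and HF_ENDPOINT and HF_API_TOKEN are set via .env.

# Illustrative usage sketch (not part of the commit): query one backend directly.
matrix, id_map, G, rdf = methods["LinkGraph"]   # or methods["InfoMatch"]
answer, ids, context, rdf_facts = generate_response(
    matrix, id_map, G, rdf, "What languages are spoken in Perú?", k=3
)
print(answer)   # model answer constrained to the retrieved context
print(ids)      # the k language identifiers that were retrieved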
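get_top_k ranks the rows of the precomputed embedding matrix by cosine similarity against the E5 query embedding (the "query: " prefix follows the multilingual-e5 input convention) and keeps the k best matches. A toy sketch of that ranking step, with made-up numbers and hypothetical language IDs:

import numpy as np

# Toy illustration of the ranking inside get_top_k: cosine similarity of a
# query vector against each row of an embedding matrix, then take the top k.
matrix = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])   # made-up row embeddings
id_map = {0: "lang_a", 1: "lang_b", 2: "lang_c"}          # hypothetical IDs
vec = np.array([1.0, 0.2], dtype="float32")               # pretend query embedding

sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
top_k_idx = np.argsort(sims)[-2:][::-1]                   # k = 2
print([id_map[i] for i in top_k_idx])                     # ['lang_a', 'lang_b']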
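query_rdf simply asks for every (property, value) pair attached to ex:<lang_id> in the loaded TTL graph. A self-contained toy sketch of the same lookup; the language ID and property names below are hypothetical:

import rdflib
from rdflib import Graph as RDFGraph, Namespace

# Toy illustration of the lookup in query_rdf: all triples whose subject is
# ex:<lang_id> come back as (property, value) pairs.
EX = Namespace("http://example.org/lang/")
g = RDFGraph()
g.add((EX["lang_x"], EX["family"], rdflib.Literal("Arawakan")))
g.add((EX["lang_x"], EX["country"], rdflib.Literal("Peru")))

q = """
PREFIX ex: <http://example.org/lang/>
SELECT ?property ?value WHERE { ex:lang_x ?property ?value }
"""
for row in g.query(q):
    print(str(row[0]).split("/")[-1], "->", str(row[1]))
# prints "family -> Arawakan" and "country -> Peru" (order may vary)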