Javier Vera committed on
Commit b276dbe · verified · 1 Parent(s): 43bac59

Update rag_hf.py

Files changed (1)
  1. rag_hf.py +190 -253
rag_hf.py CHANGED
@@ -1,253 +1,190 @@
1
- # rag_interface.py (with numpy instead of faiss)
2
- import streamlit as st
3
- import pickle
4
- import numpy as np
5
- import rdflib
6
- import torch
7
- import datetime
8
- import os
9
- import requests
10
- from rdflib import Graph as RDFGraph, Namespace
11
- from sentence_transformers import SentenceTransformer
12
- from dotenv import load_dotenv
13
-
14
- # === CONFIGURATION ===
15
- load_dotenv()
16
-
17
- MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
18
- EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
- EX = Namespace("http://example.org/lang/")
21
-
22
- st.set_page_config(
23
- page_title="Vanishing Voices: Language Atlas",
24
- page_icon="🌍",
25
- layout="wide",
26
- initial_sidebar_state="expanded"
27
- )
28
-
29
- # Custom CSS
30
- st.markdown("""
31
- <style>
32
- .header {
33
- color: #2c3e50;
34
- border-bottom: 2px solid #3498db;
35
- padding-bottom: 10px;
36
- margin-bottom: 1.5rem;
37
- }
38
- .info-box {
39
- background-color: #e8f4fc;
40
- border-radius: 8px;
41
- padding: 1rem;
42
- margin-bottom: 1.5rem;
43
- border-left: 4px solid #3498db;
44
- }
45
- .sidebar-section {
46
- margin-bottom: 2rem;
47
- }
48
- .sidebar-title {
49
- color: #2c3e50;
50
- font-size: 1.1rem;
51
- font-weight: 600;
52
- margin-bottom: 0.5rem;
53
- border-bottom: 1px solid #eee;
54
- padding-bottom: 0.5rem;
55
- }
56
- .method-card {
57
- background-color: #f8f9fa;
58
- border-radius: 8px;
59
- padding: 0.8rem;
60
- margin-bottom: 0.8rem;
61
- border-left: 3px solid #3498db;
62
- }
63
- .method-title {
64
- font-weight: 600;
65
- color: #3498db;
66
- margin-bottom: 0.3rem;
67
- }
68
- </style>
69
- """, unsafe_allow_html=True)
70
-
71
- @st.cache_resource(show_spinner="Loading models and indexes...")
72
- def load_all_components():
73
- embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
74
- methods = {}
75
- for label, suffix, ttl, matrix_path in [
76
- ("Standard", "", "grafo_ttl_no_hibrido.ttl", "embed_matrix.npy"),
77
- ("Hybrid", "_hybrid", "grafo_ttl_hibrido.ttl", "embed_matrix_hybrid.npy"),
78
- ("GraphSAGE", "_hybrid_graphsage", "grafo_ttl_hibrido_graphsage.ttl", "embed_matrix_hybrid_graphsage.npy")
79
- ]:
80
- with open(f"id_map{suffix}.pkl", "rb") as f:
81
- id_map = pickle.load(f)
82
- with open(f"grafo_embed{suffix}.pickle", "rb") as f:
83
- G = pickle.load(f)
84
- matrix = np.load(matrix_path)
85
- rdf = RDFGraph()
86
- rdf.parse(ttl, format="ttl")
87
- methods[label] = (matrix, id_map, G, rdf)
88
- return methods, embedder
89
-
90
- methods, embedder = load_all_components()
91
-
92
- # === CORE FUNCTIONS ===
93
- def get_top_k(matrix, id_map, query, k):
94
- vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
95
- vec = vec.cpu().numpy().astype("float32")
96
- sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
97
- top_k_idx = np.argsort(sims)[-k:][::-1]
98
- return [id_map[i] for i in top_k_idx]
99
-
100
- def get_context(G, lang_id):
101
- node = G.nodes.get(lang_id, {})
102
- lines = [f"**Language:** {node.get('label', lang_id)}"]
103
- if node.get("wikipedia_summary"):
104
- lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
105
- if node.get("wikidata_description"):
106
- lines.append(f"**Wikidata:** {node['wikidata_description']}")
107
- if node.get("wikidata_countries"):
108
- lines.append(f"**Countries:** {node['wikidata_countries']}")
109
- return "\n\n".join(lines)
110
-
111
- def query_rdf(rdf, lang_id):
112
- q = f"""
113
- PREFIX ex: <http://example.org/lang/>
114
- SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
115
- """
116
- try:
117
- return [
118
- (str(row[0]).split("/")[-1], str(row[1]))
119
- for row in rdf.query(q)
120
- ]
121
- except Exception as e:
122
- return [("error", str(e))]
123
-
124
- def generate_response(matrix, id_map, G, rdf, user_question, k=3):
125
- ids = get_top_k(matrix, id_map, user_question, k)
126
- context = [get_context(G, i) for i in ids]
127
- rdf_facts = []
128
- for i in ids:
129
- rdf_facts.extend([f"{p}: {v}" for p, v in query_rdf(rdf, i)])
130
- prompt = f"""<s>[INST]
131
- You are an expert in South American indigenous languages.
132
- Use strictly and only the information below to answer the user question in **English**.
133
- - Do not infer or assume facts that are not explicitly stated.
134
- - If the answer is unknown or insufficient, say "I cannot answer with the available data."
135
- - Limit your answer to 100 words.
136
-
137
-
138
- ### CONTEXT:
139
- {chr(10).join(context)}
140
-
141
- ### RDF RELATIONS:
142
- {chr(10).join(rdf_facts)}
143
-
144
- ### QUESTION:
145
- {user_question}
146
-
147
- Answer:
148
- [/INST]"""
149
- try:
150
- res = requests.post(
151
- f"https://api-inference.huggingface.co/models/{MODEL_ID}",
152
- headers={"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}", "Content-Type": "application/json"},
153
- json={"inputs": prompt}, timeout=30
154
- )
155
- out = res.json()
156
- if isinstance(out, list) and "generated_text" in out[0]:
157
- return out[0]["generated_text"].replace(prompt.strip(), "").strip(), ids, context, rdf_facts
158
- return str(out), ids, context, rdf_facts
159
- except Exception as e:
160
- return str(e), ids, context, rdf_facts
161
-
162
- # === MAIN FUNCTION ===
163
- def main():
164
- st.markdown("""
165
- <h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
166
- <div class='info-box'>
167
- <b>Linguistic Emergency:</b> Over 40% of South America's indigenous languages face extinction.
168
- This tool documents these cultural treasures before they disappear forever.
169
- </div>
170
- """, unsafe_allow_html=True)
171
-
172
- with st.sidebar:
173
- st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
174
-
175
- with st.container():
176
- st.markdown('<div class="sidebar-title">About This Tool</div>', unsafe_allow_html=True)
177
- st.markdown("""
178
- <div class="method-card">
179
- <div class="method-title">Standard Search</div>
180
- Semantic retrieval based on text-only embeddings. Identifies languages using purely linguistic similarity from Wikipedia summaries and labels.
181
- </div>
182
- <div class="method-card">
183
- <div class="method-title">Hybrid Search</div>
184
- Combines semantic embeddings with structured data from knowledge graphs. Enriches language representation with contextual facts.
185
- </div>
186
- <div class="method-card">
187
- <div class="method-title">GraphSAGE Search</div>
188
- Leverages deep graph neural networks to learn relational patterns across languages. Captures complex cultural and genealogical connections.
189
- </div>
190
- """, unsafe_allow_html=True)
191
-
192
- with st.container():
193
- st.markdown('<div class="sidebar-title">Research Settings</div>', unsafe_allow_html=True)
194
- k = st.slider("Languages to analyze per query", 1, 10, 3)
195
- st.markdown("**Display Options:**")
196
- show_ids = st.checkbox("Language IDs", value=True, key="show_ids")
197
- show_ctx = st.checkbox("Cultural Context", value=True, key="show_ctx")
198
- show_rdf = st.checkbox("RDF Relations", value=True, key="show_rdf")
199
-
200
- with st.container():
201
- st.markdown('<div class="sidebar-title">Data Sources</div>', unsafe_allow_html=True)
202
- st.markdown("""
203
- - Glottolog
204
- - Wikidata
205
- - Wikipedia
206
- - Ethnologue
207
- """)
208
-
209
- query = st.text_input("Ask about indigenous languages:", "Which Amazonian languages are most at risk?")
210
-
211
- if st.button("Analyze with All Methods") and query:
212
- col1, col2, col3 = st.columns(3)
213
- results = {}
214
- for col, (label, method) in zip([col1, col2, col3], methods.items()):
215
- with col:
216
- st.subheader(f"{label} Analysis")
217
- start = datetime.datetime.now()
218
- response, lang_ids, context, rdf_data = generate_response(*method, query, k)
219
- duration = (datetime.datetime.now() - start).total_seconds()
220
- st.markdown(response)
221
- st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
222
- if show_ids:
223
- st.markdown("**Language Identifiers:**")
224
- st.code("\n".join(lang_ids))
225
- if show_ctx:
226
- st.markdown("**Cultural Context:**")
227
- st.markdown("\n\n---\n\n".join(context))
228
- if show_rdf:
229
- st.markdown("**RDF Knowledge:**")
230
- st.code("\n".join(rdf_data))
231
- results[label] = response
232
-
233
- log = f"""
234
- [{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]
235
- QUERY: {query}
236
- STANDARD:
237
- {results.get('Standard', '')}
238
-
239
- HYBRID:
240
- {results.get('Hybrid', '')}
241
-
242
- GRAPH-SAGE:
243
- {results.get('GraphSAGE', '')}
244
- {'='*60}
245
- """
246
- try:
247
- with open("language_analysis_logs.txt", "a", encoding="utf-8") as f:
248
- f.write(log)
249
- except Exception as e:
250
- st.warning(f"Failed to log: {str(e)}")
251
-
252
- if __name__ == "__main__":
253
- main()
 
1
+ # rag_interface.py (Hybrid & GraphSAGE only, simplified explanations, renamed methods)
2
+ import streamlit as st
3
+ import pickle
4
+ import numpy as np
5
+ import rdflib
6
+ import torch
7
+ import datetime
8
+ import os
9
+ import requests
10
+ from rdflib import Graph as RDFGraph, Namespace
11
+ from sentence_transformers import SentenceTransformer
12
+ from dotenv import load_dotenv
13
+
14
+ # === CONFIGURATION ===
15
+ load_dotenv()
16
+
17
+ ENDPOINT_URL = os.getenv("HF_ENDPOINT")
18
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
19
+ EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
20
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
+ EX = Namespace("http://example.org/lang/")
22
+
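> The configuration block above pulls `HF_ENDPOINT` and `HF_API_TOKEN` from the environment via `python-dotenv`. A minimal sketch of a fail-fast check for those two variables, assuming a `.env` file with placeholder-style values (the real endpoint URL and token are not part of this commit):

```python
# Sketch: fail fast if the variables read by the configuration block are
# missing. The .env file is assumed to contain placeholder-style entries:
#   HF_ENDPOINT=https://<your-endpoint>.endpoints.huggingface.cloud
#   HF_API_TOKEN=hf_xxx
import os
from dotenv import load_dotenv

load_dotenv()
for var in ("HF_ENDPOINT", "HF_API_TOKEN"):
    if not os.getenv(var):
        raise RuntimeError(f"Missing required environment variable: {var}")
```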
23
+ st.set_page_config(
24
+ page_title="Vanishing Voices: Language Atlas",
25
+ page_icon="🌍",
26
+ layout="wide",
27
+ initial_sidebar_state="expanded"
28
+ )
29
+
30
+ # Custom CSS
31
+ st.markdown("""
32
+ <style>
33
+ .header {
34
+ color: #2c3e50;
35
+ border-bottom: 2px solid #3498db;
36
+ padding-bottom: 10px;
37
+ margin-bottom: 1.5rem;
38
+ }
39
+ .info-box {
40
+ background-color: #e8f4fc;
41
+ border-radius: 8px;
42
+ padding: 1rem;
43
+ margin-bottom: 1.5rem;
44
+ border-left: 4px solid #3498db;
45
+ }
46
+ .sidebar-title {
47
+ font-size: 1.1rem;
48
+ font-weight: 600;
49
+ margin-top: 1rem;
50
+ }
51
+ </style>
52
+ """, unsafe_allow_html=True)
53
+
54
+ @st.cache_resource(show_spinner="Loading models and indexes...")
55
+ def load_all_components():
56
+ embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
57
+ methods = {}
58
+ for label, suffix, ttl, matrix_path in [
59
+ ("InfoMatch", "_hybrid", "grafo_ttl_hibrido.ttl", "embed_matrix_hybrid.npy"),
60
+ ("LinkGraph", "_hybrid_graphsage", "grafo_ttl_hibrido_graphsage.ttl", "embed_matrix_hybrid_graphsage.npy")
61
+ ]:
62
+ with open(f"id_map{suffix}.pkl", "rb") as f:
63
+ id_map = pickle.load(f)
64
+ with open(f"grafo_embed{suffix}.pickle", "rb") as f:
65
+ G = pickle.load(f)
66
+ matrix = np.load(matrix_path)
67
+ rdf = RDFGraph()
68
+ rdf.parse(ttl, format="ttl")
69
+ methods[label] = (matrix, id_map, G, rdf)
70
+ return methods, embedder
71
+
72
+ methods, embedder = load_all_components()
73
+
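> `load_all_components` expects three aligned artifacts per method: a pickled id map whose position `i` names the language in row `i` of the `.npy` embedding matrix, a pickled graph with per-node metadata, and a Turtle file for RDF lookups. A hypothetical build step showing how such a matrix/id-map pair could be produced so that alignment holds (file names mirror the `_hybrid` variant; the toy descriptions and the `passage:` prefix, which follows the e5 convention, are assumptions):

```python
# Hypothetical build step: encode one "passage" per language and save the
# matrix together with an id_map whose position i corresponds to row i.
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("intfloat/multilingual-e5-base")
languages = {  # toy id -> description mapping
    "lang_001": "Quechua, a language family spoken in the Andes.",
    "lang_002": "Aymara, spoken around Lake Titicaca.",
}

id_map = list(languages.keys())
matrix = np.asarray(
    embedder.encode([f"passage: {languages[i]}" for i in id_map])
).astype("float32")

np.save("embed_matrix_hybrid.npy", matrix)
with open("id_map_hybrid.pkl", "wb") as f:
    pickle.dump(id_map, f)
```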
74
+ # === CORE FUNCTIONS ===
75
+ def get_top_k(matrix, id_map, query, k):
76
+ vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
77
+ vec = vec.cpu().numpy().astype("float32")
78
+ sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
79
+ top_k_idx = np.argsort(sims)[-k:][::-1]
80
+ return [id_map[i] for i in top_k_idx]
81
+
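> `get_top_k` ranks every row of the matrix by cosine similarity against the query embedding (the `1e-10` term guards against division by zero) and returns the ids of the `k` best rows; the `query:` prefix matches the query/passage convention of the e5 models. A self-contained check of the same ranking logic on random vectors:

```python
# Self-contained sketch of the ranking used in get_top_k, on random data.
import numpy as np

rng = np.random.default_rng(0)
matrix = rng.normal(size=(100, 768)).astype("float32")  # 100 fake language embeddings
vec = rng.normal(size=768).astype("float32")            # fake query embedding

sims = matrix @ vec / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
top_k_idx = np.argsort(sims)[-3:][::-1]                 # 3 most similar rows, best first
print(top_k_idx, sims[top_k_idx])
```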
82
+ def get_context(G, lang_id):
83
+ node = G.nodes.get(lang_id, {})
84
+ lines = [f"**Language:** {node.get('label', lang_id)}"]
85
+ if node.get("wikipedia_summary"):
86
+ lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
87
+ if node.get("wikidata_description"):
88
+ lines.append(f"**Wikidata:** {node['wikidata_description']}")
89
+ if node.get("wikidata_countries"):
90
+ lines.append(f"**Countries:** {node['wikidata_countries']}")
91
+ return "\n\n".join(lines)
92
+
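> `get_context` reads the node attributes `label`, `wikipedia_summary`, `wikidata_description`, and `wikidata_countries`, skipping any that are absent. The `G.nodes.get(...)` access pattern suggests the pickled graph is a networkx graph; under that assumption, a toy node shows the expected shape (the id and attribute values are illustrative):

```python
# Toy node carrying the metadata fields that get_context looks for
# (networkx assumed; "lang_001" and the attribute values are made up).
import networkx as nx

G = nx.Graph()
G.add_node(
    "lang_001",
    label="Quechua",
    wikipedia_summary="Quechua is a family of languages spoken in the Andes.",
    wikidata_description="language family of South America",
    wikidata_countries="Peru, Bolivia, Ecuador",
)
print(G.nodes.get("lang_001", {}).get("label"))  # -> Quechua
```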
93
+ def query_rdf(rdf, lang_id):
94
+ q = f"""
95
+ PREFIX ex: <http://example.org/lang/>
96
+ SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
97
+ """
98
+ try:
99
+ return [(str(row[0]).split("/")[-1], str(row[1])) for row in rdf.query(q)]
100
+ except Exception as e:
101
+ return [("error", str(e))]
102
+
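> `query_rdf` interpolates the language id directly into a SPARQL SELECT over all triples with subject `ex:<lang_id>`, then shortens each property URI to its last path segment, so ids are assumed to be valid local names. A minimal rdflib round trip of the same pattern against a toy graph (the triples are illustrative):

```python
# Minimal rdflib sketch of the query pattern used in query_rdf.
from rdflib import Graph, Literal, Namespace

EX = Namespace("http://example.org/lang/")
g = Graph()
g.add((EX["lang_001"], EX["spokenIn"], Literal("Peru")))      # toy triple
g.add((EX["lang_001"], EX["speakers"], Literal("8000000")))   # toy triple

q = """
PREFIX ex: <http://example.org/lang/>
SELECT ?property ?value WHERE { ex:lang_001 ?property ?value }
"""
for row in g.query(q):
    print(str(row[0]).split("/")[-1], str(row[1]))  # e.g. spokenIn Peru
```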
103
+ def generate_response(matrix, id_map, G, rdf, user_question, k=3):
104
+ ids = get_top_k(matrix, id_map, user_question, k)
105
+ context = [get_context(G, i) for i in ids]
106
+ rdf_facts = []
107
+ for i in ids:
108
+ rdf_facts.extend([f"{p}: {v}" for p, v in query_rdf(rdf, i)])
109
+ prompt = f"""<s>[INST]
110
+ You are an expert in South American indigenous languages.
111
+ Use strictly and only the information below to answer the user question in **English**.
112
+ - Do not infer or assume facts that are not explicitly stated.
113
+ - If the answer is unknown or insufficient, say \"I cannot answer with the available data.\"
114
+ - Limit your answer to 100 words.
115
+
116
+
117
+ ### CONTEXT:
118
+ {chr(10).join(context)}
119
+
120
+ ### RDF RELATIONS:
121
+ {chr(10).join(rdf_facts)}
122
+
123
+ ### QUESTION:
124
+ {user_question}
125
+
126
+ Answer:
127
+ [/INST]"""
128
+ try:
129
+ res = requests.post(
130
+ ENDPOINT_URL,
131
+ headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
132
+ json={"inputs": prompt}, timeout=60
133
+ )
134
+ out = res.json()
135
+ if isinstance(out, list) and "generated_text" in out[0]:
136
+ return out[0]["generated_text"].replace(prompt.strip(), "").strip(), ids, context, rdf_facts
137
+ return str(out), ids, context, rdf_facts
138
+ except Exception as e:
139
+ return str(e), ids, context, rdf_facts
140
+
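> `generate_response` posts the bare `{"inputs": prompt}` payload to the dedicated endpoint and afterwards strips the echoed prompt out of `generated_text`. Text-generation endpoints commonly also accept a `parameters` object (for example `max_new_tokens` and `return_full_text`), which would make the stripping step unnecessary; whether this particular deployment honours those parameters is an assumption. A hedged sketch of that variant:

```python
# Hedged sketch: the same endpoint call with explicit generation parameters.
# "return_full_text": False asks the server not to echo the prompt; support
# for these parameters depends on how the endpoint was deployed.
import os
import requests

def call_endpoint(prompt: str) -> str:
    res = requests.post(
        os.getenv("HF_ENDPOINT"),
        headers={
            "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
            "Content-Type": "application/json",
        },
        json={
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "return_full_text": False},
        },
        timeout=60,
    )
    out = res.json()
    if isinstance(out, list) and out and "generated_text" in out[0]:
        return out[0]["generated_text"].strip()
    return str(out)
```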
141
+ # === MAIN FUNCTION ===
142
+ def main():
143
+ st.markdown("""
144
+ <h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
145
+ <div class='info-box'>
146
+ <b>Why this matters:</b> Many indigenous languages in South America are disappearing. This app helps understand and preserve them using artificial intelligence.
147
+ </div>
148
+ """, unsafe_allow_html=True)
149
+
150
+ with st.sidebar:
151
+ st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
152
+
153
+ st.markdown("### What are the methods?")
154
+ st.markdown("""
155
+ - **InfoMatch**: Combines descriptions, country info, and speaker data using classic node2vec embeddings.
156
+ - **LinkGraph**: Uses graph learning (GraphSAGE) to detect patterns in how languages relate to each other.
157
+ """)
158
+
159
+ st.markdown("### Options")
160
+ k = st.slider("How many languages to analyze?", 1, 10, 3)
161
+ show_ids = st.checkbox("Show IDs", value=True)
162
+ show_ctx = st.checkbox("Show Text Info", value=True)
163
+ show_rdf = st.checkbox("Show Extra Facts", value=True)
164
+
165
+ query = st.text_input("Ask something about South American languages:", "What languages are spoken in Perú?")
166
+
167
+ if st.button("Analyze") and query:
168
+ col1, col2 = st.columns(2)
169
+ results = {}
170
+ for col, (label, method) in zip([col1, col2], methods.items()):
171
+ with col:
172
+ st.subheader(f"{label} Method")
173
+ start = datetime.datetime.now()
174
+ response, lang_ids, context, rdf_data = generate_response(*method, query, k)
175
+ duration = (datetime.datetime.now() - start).total_seconds()
176
+ st.markdown(response)
177
+ st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
178
+ if show_ids:
179
+ st.markdown("**Language IDs:**")
180
+ st.code("\n".join(lang_ids))
181
+ if show_ctx:
182
+ st.markdown("**Text Info:**")
183
+ st.markdown("\n\n---\n\n".join(context))
184
+ if show_rdf:
185
+ st.markdown("**Extra Facts:**")
186
+ st.code("\n".join(rdf_data))
187
+
188
+ if __name__ == "__main__":
189
+ main()
190
+
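> To try the updated script locally, the standard Streamlit entry point applies: `streamlit run rag_hf.py`, with the pickled id maps and graphs, the `.npy` matrices, and the `.ttl` files in the working directory and the environment variables described above populated.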