Javier Vera committed
Commit cf45911 · verified · 1 Parent(s): d5a9a0d

Update rag_hf.py

Files changed (1)
  1. rag_hf.py +129 -34
rag_hf.py CHANGED
@@ -1,4 +1,4 @@
-# rag_interface.py (three methods, technical descriptions)
+# rag_interface.py (with numpy instead of faiss)
 import streamlit as st
 import pickle
 import numpy as np
@@ -14,13 +14,59 @@ from dotenv import load_dotenv
 # === CONFIGURATION ===
 load_dotenv()
 
-ENDPOINT_URL = os.getenv("HF_ENDPOINT")
-HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
 EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 EX = Namespace("http://example.org/lang/")
 
-st.set_page_config(page_title="Vanishing Voices: Language Atlas", page_icon="🌍", layout="wide")
+st.set_page_config(
+    page_title="Vanishing Voices: Language Atlas",
+    page_icon="🌍",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Custom CSS
+st.markdown("""
+<style>
+.header {
+    color: #2c3e50;
+    border-bottom: 2px solid #3498db;
+    padding-bottom: 10px;
+    margin-bottom: 1.5rem;
+}
+.info-box {
+    background-color: #e8f4fc;
+    border-radius: 8px;
+    padding: 1rem;
+    margin-bottom: 1.5rem;
+    border-left: 4px solid #3498db;
+}
+.sidebar-section {
+    margin-bottom: 2rem;
+}
+.sidebar-title {
+    color: #2c3e50;
+    font-size: 1.1rem;
+    font-weight: 600;
+    margin-bottom: 0.5rem;
+    border-bottom: 1px solid #eee;
+    padding-bottom: 0.5rem;
+}
+.method-card {
+    background-color: #f8f9fa;
+    border-radius: 8px;
+    padding: 0.8rem;
+    margin-bottom: 0.8rem;
+    border-left: 3px solid #3498db;
+}
+.method-title {
+    font-weight: 600;
+    color: #3498db;
+    margin-bottom: 0.3rem;
+}
+</style>
+""", unsafe_allow_html=True)
 
 @st.cache_resource(show_spinner="Loading models and indexes...")
 def load_all_components():
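Note: the endpoint-specific variables are gone; only `HF_API_TOKEN` is still read from the environment (see the `requests.post` hunk below). A minimal sketch of the `.env` setup this configuration expects; the token value is a placeholder and the assert is illustrative:

```python
# .env file read by load_dotenv() -- variable name from the diff,
# value is a placeholder; never commit a real token:
#   HF_API_TOKEN=hf_xxxxxxxxxxxxxxxx

import os
from dotenv import load_dotenv

load_dotenv()  # populates os.environ from .env if the file exists
assert os.getenv("HF_API_TOKEN"), "HF_API_TOKEN unset: Inference API calls will fail with 401"
```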
@@ -43,6 +89,7 @@ def load_all_components():
 
 methods, embedder = load_all_components()
 
+# === CORE FUNCTIONS ===
 def get_top_k(matrix, id_map, query, k):
     vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
     vec = vec.cpu().numpy().astype("float32")
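The hunk shows only the query being embedded; the rest of `get_top_k` is not in this diff. A minimal sketch of how a numpy-only top-k (the faiss replacement named in the new header comment) could complete, assuming `matrix` holds L2-normalized passage embeddings and `id_map` maps row indices to language IDs:

```python
import numpy as np

def get_top_k_numpy(matrix, id_map, query_vec, k):
    """Hypothetical completion: cosine top-k via a dot product,
    assuming rows of `matrix` and `query_vec` are L2-normalized."""
    sims = matrix @ query_vec          # dot product == cosine similarity
    top = np.argsort(-sims)[:k]        # indices of the k highest scores
    return [id_map[i] for i in top]

# Toy run: three fake language vectors in a 4-d space
rng = np.random.default_rng(0)
m = rng.normal(size=(3, 4)).astype("float32")
m /= np.linalg.norm(m, axis=1, keepdims=True)
q = m[1] / np.linalg.norm(m[1])      # query closest to row 1
print(get_top_k_numpy(m, {0: "quechua", 1: "aymara", 2: "shipibo"}, q, 2))
```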
@@ -52,14 +99,14 @@ def get_top_k(matrix, id_map, query, k):
 
 def get_context(G, lang_id):
     node = G.nodes.get(lang_id, {})
-    parts = [f"**Language:** {node.get('label', lang_id)}"]
+    lines = [f"**Language:** {node.get('label', lang_id)}"]
     if node.get("wikipedia_summary"):
-        parts.append(f"**Wikipedia:** {node['wikipedia_summary']}")
+        lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
     if node.get("wikidata_description"):
-        parts.append(f"**Wikidata:** {node['wikidata_description']}")
+        lines.append(f"**Wikidata:** {node['wikidata_description']}")
     if node.get("wikidata_countries"):
-        parts.append(f"**Countries:** {node['wikidata_countries']}")
-    return "\n\n".join(parts)
+        lines.append(f"**Countries:** {node['wikidata_countries']}")
+    return "\n\n".join(lines)
 
 def query_rdf(rdf, lang_id):
     q = f"""
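For orientation, a toy invocation of `get_context` as defined in the hunk above, on an invented networkx node carrying the attributes the function looks up:

```python
import networkx as nx

# Invented node attributes matching the keys get_context() reads
G = nx.Graph()
G.add_node(
    "lang001",
    label="Sample Language",
    wikipedia_summary="Spoken in the western Amazon basin.",
    wikidata_description="endangered indigenous language",
    wikidata_countries="Peru, Brazil",
)

print(get_context(G, "lang001"))
# **Language:** Sample Language
#
# **Wikipedia:** Spoken in the western Amazon basin.
# ... (one block per attribute that is present)
```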
@@ -67,7 +114,10 @@ def query_rdf(rdf, lang_id):
     SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
     """
     try:
-        return [(str(row[0]).split("/")[-1], str(row[1])) for row in rdf.query(q)]
+        return [
+            (str(row[0]).split("/")[-1], str(row[1]))
+            for row in rdf.query(q)
+        ]
     except Exception as e:
         return [("error", str(e))]
 
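A self-contained sketch of the same SPARQL pattern against an in-memory rdflib graph; the triples here are invented for illustration:

```python
from rdflib import Graph, Literal, Namespace

EX = Namespace("http://example.org/lang/")
rdf = Graph()
rdf.bind("ex", EX)
rdf.add((EX["abcd1234"], EX.family, Literal("Arawakan")))   # sample triples
rdf.add((EX["abcd1234"], EX.speakers, Literal(1200)))

q = """
PREFIX ex: <http://example.org/lang/>
SELECT ?property ?value WHERE { ex:abcd1234 ?property ?value }
"""
for row in rdf.query(q):
    # Keep only the last path segment of the property URI, as query_rdf does
    print(str(row[0]).split("/")[-1], "->", str(row[1]))
# family -> Arawakan
# speakers -> 1200
```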
@@ -84,6 +134,7 @@ Use strictly and only the information below to answer the user question in **Eng
 - If the answer is unknown or insufficient, say "I cannot answer with the available data."
 - Limit your answer to 100 words.
 
+
 ### CONTEXT:
 {chr(10).join(context)}
 
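The `{chr(10).join(context)}` is just `"\n".join(context)`: `chr(10)` sidesteps the ban on backslashes inside f-string expressions in Python versions before 3.12. A toy rendering of this part of the template:

```python
context = [
    "**Language:** Aymara\n\n**Countries:** Bolivia, Peru",
    "**Language:** Shipibo-Conibo\n\n**Countries:** Peru",
]

# chr(10) == "\n"; usable inside the f-string expression on any Python version
section = f"""### CONTEXT:
{chr(10).join(context)}
"""
print(section)  # one retrieved-language block per line-separated entry
```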
@@ -97,9 +148,9 @@ Answer:
 [/INST]"""
     try:
         res = requests.post(
-            ENDPOINT_URL,
-            headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
-            json={"inputs": prompt}, timeout=60
+            f"https://api-inference.huggingface.co/models/{MODEL_ID}",
+            headers={"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}", "Content-Type": "application/json"},
+            json={"inputs": prompt}, timeout=30
         )
         out = res.json()
         if isinstance(out, list) and "generated_text" in out[0]:
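The change swaps a dedicated endpoint for the serverless Inference API. A standalone sketch of the same call, with the response-shape check the diff relies on (URL, model ID, and payload taken from the hunk; helper name is ours):

```python
import os
import requests

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"

def ask(prompt: str) -> str:
    res = requests.post(
        f"https://api-inference.huggingface.co/models/{MODEL_ID}",
        headers={"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"},
        json={"inputs": prompt},
        timeout=30,
    )
    out = res.json()
    # Success looks like [{"generated_text": "..."}]; errors come back as a dict
    if isinstance(out, list) and "generated_text" in out[0]:
        return out[0]["generated_text"]
    return str(out)
```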
@@ -108,6 +159,7 @@ Answer:
     except Exception as e:
         return str(e), ids, context, rdf_facts
 
+# === MAIN FUNCTION ===
 def main():
     st.markdown("""
     <h1 class='header'>Vanishing Voices: South America's Endangered Language Atlas</h1>
@@ -120,39 +172,82 @@ def main():
     with st.sidebar:
         st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
 
-        st.markdown("### About This Tool")
-        st.markdown("""
-        - **Standard**: Semantic search based on text-only embeddings.
-        - **Hybrid**: Uses node2vec to combine graph structure with descriptive features.
-        - **GraphSAGE**: Employs deep graph learning (GraphSAGE) for relational patterns.
-        """)
-
-        k = st.slider("Languages to analyze per query", 1, 10, 3)
-        show_ids = st.checkbox("Language IDs", value=True)
-        show_ctx = st.checkbox("Contextual Info", value=True)
-        show_rdf = st.checkbox("RDF Relations", value=True)
-
-    query = st.text_input("Ask something about South American languages:", "Which Amazonian languages are most at risk?")
-
-    if st.button("Analyze"):
-        cols = st.columns(len(methods))
-        for col, (label, method) in zip(cols, methods.items()):
+        with st.container():
+            st.markdown('<div class="sidebar-title">About This Tool</div>', unsafe_allow_html=True)
+            st.markdown("""
+            <div class="method-card">
+            <div class="method-title">Standard Search</div>
+            Semantic retrieval based on text-only embeddings. Identifies languages using purely linguistic similarity from Wikipedia summaries and labels.
+            </div>
+            <div class="method-card">
+            <div class="method-title">Hybrid Search</div>
+            Combines semantic embeddings with structured data from knowledge graphs. Enriches language representation with contextual facts.
+            </div>
+            <div class="method-card">
+            <div class="method-title">GraphSAGE Search</div>
+            Leverages deep graph neural networks to learn relational patterns across languages. Captures complex cultural and genealogical connections.
+            </div>
+            """, unsafe_allow_html=True)
+
+        with st.container():
+            st.markdown('<div class="sidebar-title">Research Settings</div>', unsafe_allow_html=True)
+            k = st.slider("Languages to analyze per query", 1, 10, 3)
+            st.markdown("**Display Options:**")
+            show_ids = st.checkbox("Language IDs", value=True, key="show_ids")
+            show_ctx = st.checkbox("Cultural Context", value=True, key="show_ctx")
+            show_rdf = st.checkbox("RDF Relations", value=True, key="show_rdf")
+
+        with st.container():
+            st.markdown('<div class="sidebar-title">Data Sources</div>', unsafe_allow_html=True)
+            st.markdown("""
+            - Glottolog
+            - Wikidata
+            - Wikipedia
+            - Ethnologue
+            """)
+
+    query = st.text_input("Ask about indigenous languages:", "Which Amazonian languages are most at risk?")
+
+    if st.button("Analyze with All Methods") and query:
+        col1, col2, col3 = st.columns(3)
+        results = {}
+        for col, (label, method) in zip([col1, col2, col3], methods.items()):
             with col:
-                st.subheader(label)
+                st.subheader(f"{label} Analysis")
                 start = datetime.datetime.now()
                 response, lang_ids, context, rdf_data = generate_response(*method, query, k)
                 duration = (datetime.datetime.now() - start).total_seconds()
                 st.markdown(response)
                 st.markdown(f"⏱️ {duration:.2f}s | 🌐 {len(lang_ids)} languages")
                 if show_ids:
-                    st.markdown("**IDs:**")
+                    st.markdown("**Language Identifiers:**")
                     st.code("\n".join(lang_ids))
                 if show_ctx:
-                    st.markdown("**Context:**")
+                    st.markdown("**Cultural Context:**")
                     st.markdown("\n\n---\n\n".join(context))
                 if show_rdf:
-                    st.markdown("**RDF:**")
+                    st.markdown("**RDF Knowledge:**")
                     st.code("\n".join(rdf_data))
+                results[label] = response
+
+        log = f"""
+[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]
+QUERY: {query}
+STANDARD:
+{results.get('Standard', '')}
+
+HYBRID:
+{results.get('Hybrid', '')}
+
+GRAPH-SAGE:
+{results.get('GraphSAGE', '')}
+{'='*60}
+"""
+        try:
+            with open("language_analysis_logs.txt", "a", encoding="utf-8") as f:
+                f.write(log)
+        except Exception as e:
+            st.warning(f"Failed to log: {str(e)}")
 
 if __name__ == "__main__":
     main()
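The new logging block appends one plain-text record per query. Extracted as a standalone helper (file name, sample data, and the abridged single-method body are illustrative), the pattern is:

```python
import datetime

def append_log(path, query, results):
    """Same append-only pattern as main(): one timestamped block
    per query, closed by a '=' rule. Abridged to one method here."""
    log = f"""
[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]
QUERY: {query}
STANDARD:
{results.get('Standard', '')}
{'=' * 60}
"""
    with open(path, "a", encoding="utf-8") as f:  # append, never overwrite
        f.write(log)

append_log("demo_logs.txt",
           "Which Amazonian languages are most at risk?",
           {"Standard": "Example answer."})
```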
 