Javier Vera
committed on
Update rag_hf.py
Browse files
rag_hf.py
CHANGED
@@ -1,26 +1,24 @@
|
|
1 |
-
# rag_interface.py (Hybrid & GraphSAGE only, simplified explanations, renamed methods)
|
2 |
import streamlit as st
|
|
|
3 |
import pickle
|
4 |
import numpy as np
|
5 |
import rdflib
|
6 |
import torch
|
7 |
-
import datetime
|
8 |
import os
|
9 |
import requests
|
10 |
from rdflib import Graph as RDFGraph, Namespace
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from dotenv import load_dotenv
|
13 |
|
14 |
-
# === CONFIGURATION ===
|
15 |
load_dotenv()
|
16 |
-
|
17 |
ENDPOINT_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
|
18 |
-
|
19 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
20 |
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
|
21 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
22 |
EX = Namespace("http://example.org/lang/")
|
23 |
|
|
|
24 |
st.set_page_config(
|
25 |
page_title="Vanishing Voices: Language Atlas",
|
26 |
page_icon="🌍",
|
@@ -28,7 +26,7 @@ st.set_page_config(
|
|
28 |
initial_sidebar_state="expanded"
|
29 |
)
|
30 |
|
31 |
-
#
|
32 |
st.markdown("""
|
33 |
<style>
|
34 |
.header {
|
@@ -37,21 +35,32 @@ st.markdown("""
|
|
37 |
padding-bottom: 10px;
|
38 |
margin-bottom: 1.5rem;
|
39 |
}
|
40 |
-
.
|
41 |
-
background-color: #
|
42 |
border-radius: 8px;
|
43 |
-
padding:
|
44 |
-
margin
|
45 |
border-left: 4px solid #3498db;
|
46 |
}
|
47 |
-
.
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
}
|
52 |
</style>
|
53 |
""", unsafe_allow_html=True)
|
54 |
|
|
|
55 |
@st.cache_resource(show_spinner="Loading models and indexes...")
|
56 |
def load_all_components():
|
57 |
embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
|
@@ -72,7 +81,7 @@ def load_all_components():
|
|
72 |
|
73 |
methods, embedder = load_all_components()
|
74 |
|
75 |
-
# === CORE FUNCTIONS ===
|
76 |
def get_top_k(matrix, id_map, query, k):
|
77 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
78 |
vec = vec.cpu().numpy().astype("float32")
|
@@ -113,17 +122,12 @@ Use strictly and only the information below to answer the user question in **Eng
|
|
113 |
- Do not infer or assume facts that are not explicitly stated.
|
114 |
- If the answer is unknown or insufficient, say \"I cannot answer with the available data.\"
|
115 |
- Limit your answer to 100 words.
|
116 |
-
|
117 |
-
|
118 |
### CONTEXT:
|
119 |
{chr(10).join(context)}
|
120 |
-
|
121 |
### RDF RELATIONS:
|
122 |
{chr(10).join(rdf_facts)}
|
123 |
-
|
124 |
### QUESTION:
|
125 |
{user_question}
|
126 |
-
|
127 |
Answer:
|
128 |
[/INST]"""
|
129 |
try:
|
@@ -139,33 +143,34 @@ Answer:
|
|
139 |
except Exception as e:
|
140 |
return str(e), ids, context, rdf_facts
|
141 |
|
142 |
-
# === MAIN FUNCTION ===
|
143 |
def main():
|
144 |
st.markdown("""
|
145 |
-
<
|
146 |
-
|
147 |
-
|
|
|
|
|
148 |
</div>
|
149 |
""", unsafe_allow_html=True)
|
150 |
|
151 |
with st.sidebar:
|
152 |
st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
|
153 |
-
|
154 |
-
st.markdown("### What are the methods?")
|
155 |
st.markdown("""
|
156 |
-
- **
|
157 |
-
- **
|
158 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
|
161 |
-
k = st.slider("How many languages to analyze?", 1, 10, 3)
|
162 |
-
show_ids = st.checkbox("Show IDs", value=True)
|
163 |
-
show_ctx = st.checkbox("Show Text Info", value=True)
|
164 |
-
show_rdf = st.checkbox("Show Extra Facts", value=True)
|
165 |
-
|
166 |
-
query = st.text_input("Ask something about South American languages:", "What languages are spoken in Perú?")
|
167 |
|
168 |
-
if st.button("Analyze")
|
169 |
col1, col2 = st.columns(2)
|
170 |
results = {}
|
171 |
for col, (label, method) in zip([col1, col2], methods.items()):
|
@@ -174,17 +179,31 @@ def main():
|
|
174 |
start = datetime.datetime.now()
|
175 |
response, lang_ids, context, rdf_data = generate_response(*method, query, k)
|
176 |
duration = (datetime.datetime.now() - start).total_seconds()
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
if show_ids:
|
180 |
-
st.
|
181 |
-
|
|
|
182 |
if show_ctx:
|
183 |
-
st.
|
184 |
-
|
|
|
|
|
185 |
if show_rdf:
|
186 |
-
st.
|
187 |
-
|
188 |
|
189 |
if __name__ == "__main__":
|
190 |
main()
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import datetime
|
3 |
import pickle
|
4 |
import numpy as np
|
5 |
import rdflib
|
6 |
import torch
|
|
|
7 |
import os
|
8 |
import requests
|
9 |
from rdflib import Graph as RDFGraph, Namespace
|
10 |
from sentence_transformers import SentenceTransformer
|
11 |
from dotenv import load_dotenv
|
12 |
|
13 |
+
# === ORIGINAL CONFIGURATION (unchanged) ===
|
14 |
load_dotenv()
|
|
|
15 |
ENDPOINT_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
|
|
|
16 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
17 |
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
|
18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
EX = Namespace("http://example.org/lang/")
|
20 |
|
21 |
+
# === IMPROVED UI SETUP ===
|
22 |
st.set_page_config(
|
23 |
page_title="Vanishing Voices: Language Atlas",
|
24 |
page_icon="🌍",
|
|
|
26 |
initial_sidebar_state="expanded"
|
27 |
)
|
28 |
|
29 |
+
# Professional CSS (visual only)
|
30 |
st.markdown("""
|
31 |
<style>
|
32 |
.header {
|
|
|
35 |
padding-bottom: 10px;
|
36 |
margin-bottom: 1.5rem;
|
37 |
}
|
38 |
+
.response-card {
|
39 |
+
background-color: #f8f9fa;
|
40 |
border-radius: 8px;
|
41 |
+
padding: 1.5rem;
|
42 |
+
margin: 1rem 0;
|
43 |
border-left: 4px solid #3498db;
|
44 |
}
|
45 |
+
.language-info {
|
46 |
+
background-color: white;
|
47 |
+
border-radius: 8px;
|
48 |
+
padding: 1rem;
|
49 |
+
margin: 0.5rem 0;
|
50 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
51 |
+
}
|
52 |
+
.metric-badge {
|
53 |
+
display: inline-block;
|
54 |
+
background-color: #e8f4fc;
|
55 |
+
padding: 0.25rem 0.5rem;
|
56 |
+
border-radius: 4px;
|
57 |
+
font-size: 0.85rem;
|
58 |
+
margin-right: 0.5rem;
|
59 |
}
|
60 |
</style>
|
61 |
""", unsafe_allow_html=True)
|
62 |
|
63 |
+
# === ORIGINAL FUNCTIONALITY (unchanged) ===
|
64 |
@st.cache_resource(show_spinner="Loading models and indexes...")
|
65 |
def load_all_components():
|
66 |
embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
|
|
|
81 |
|
82 |
methods, embedder = load_all_components()
|
83 |
|
84 |
+
# === ORIGINAL CORE FUNCTIONS (unchanged) ===
|
85 |
def get_top_k(matrix, id_map, query, k):
|
86 |
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
|
87 |
vec = vec.cpu().numpy().astype("float32")
|
|
|
122 |
- Do not infer or assume facts that are not explicitly stated.
|
123 |
- If the answer is unknown or insufficient, say \"I cannot answer with the available data.\"
|
124 |
- Limit your answer to 100 words.
|
|
|
|
|
125 |
### CONTEXT:
|
126 |
{chr(10).join(context)}
|
|
|
127 |
### RDF RELATIONS:
|
128 |
{chr(10).join(rdf_facts)}
|
|
|
129 |
### QUESTION:
|
130 |
{user_question}
|
|
|
131 |
Answer:
|
132 |
[/INST]"""
|
133 |
try:
|
|
|
143 |
except Exception as e:
|
144 |
return str(e), ids, context, rdf_facts
|
145 |
|
146 |
+
# === IMPROVED MAIN FUNCTION (same functionality, better UI) ===
|
147 |
def main():
|
148 |
st.markdown("""
|
149 |
+
<div class="header">
|
150 |
+
<h1>Vanishing Voices: South America's Endangered Language Atlas</h1>
|
151 |
+
</div>
|
152 |
+
<div style="background-color: #e8f4fc; border-radius: 8px; padding: 1rem; margin-bottom: 1.5rem;">
|
153 |
+
<b>AI-Powered Analysis:</b> This app uses Mistral-7B-Instruct with RAG (Retrieval-Augmented Generation) to analyze indigenous languages.
|
154 |
</div>
|
155 |
""", unsafe_allow_html=True)
|
156 |
|
157 |
with st.sidebar:
|
158 |
st.image("https://glottolog.org/static/img/glottolog_lod.png", width=180)
|
159 |
+
st.markdown("### Analysis Methods")
|
|
|
160 |
st.markdown("""
|
161 |
+
- **InfoMatch**: Combines text embeddings with metadata
|
162 |
+
- **LinkGraph**: Uses graph neural networks (GraphSAGE)
|
163 |
""")
|
164 |
+
|
165 |
+
# Original controls (labels reworded; note "Show RDF Facts" now defaults to off)
|
166 |
+
k = st.slider("Languages to analyze", 1, 10, 3)
|
167 |
+
show_ids = st.checkbox("Show Language IDs", True)
|
168 |
+
show_ctx = st.checkbox("Show Context Info", True)
|
169 |
+
show_rdf = st.checkbox("Show RDF Facts", False)
|
170 |
|
171 |
+
query = st.text_input("Ask about South American languages:", "What languages are spoken in Perú?")
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
if st.button("Analyze with AI"):
|
174 |
col1, col2 = st.columns(2)
|
175 |
results = {}
|
176 |
for col, (label, method) in zip([col1, col2], methods.items()):
|
|
|
179 |
start = datetime.datetime.now()
|
180 |
response, lang_ids, context, rdf_data = generate_response(*method, query, k)
|
181 |
duration = (datetime.datetime.now() - start).total_seconds()
|
182 |
+
|
183 |
+
# Improved response display
|
184 |
+
st.markdown(f"""
|
185 |
+
<div class="response-card">
|
186 |
+
{response}
|
187 |
+
<div style="margin-top: 1rem;">
|
188 |
+
<span class="metric-badge">⏱️ {duration:.2f}s</span>
|
189 |
+
<span class="metric-badge">🌐 {len(lang_ids)} languages</span>
|
190 |
+
</div>
|
191 |
+
</div>
|
192 |
+
""", unsafe_allow_html=True)
|
193 |
+
|
194 |
+
# Original debug info with better presentation
|
195 |
if show_ids:
|
196 |
+
with st.expander("Language IDs"):
|
197 |
+
st.code("\n".join(lang_ids))
|
198 |
+
|
199 |
if show_ctx:
|
200 |
+
with st.expander("Context Information"):
|
201 |
+
for ctx in context:
|
202 |
+
st.markdown(f"<div class='language-info'>{ctx}</div>", unsafe_allow_html=True)
|
203 |
+
|
204 |
if show_rdf:
|
205 |
+
with st.expander("RDF Relations"):
|
206 |
+
st.code("\n".join(rdf_data))
|
207 |
|
208 |
if __name__ == "__main__":
|
209 |
main()
|