RAG-SA / rag_hf.py
javiervz's picture
Update rag_hf.py
c683bff verified
raw
history blame
9.78 kB
import streamlit as st
import datetime
import pickle
import numpy as np
import rdflib
import torch
import os
import requests
from rdflib import Graph as RDFGraph, Namespace
from sentence_transformers import SentenceTransformer
# === STREAMLIT UI CONFIG ===
# Text shown under the "About" entry of Streamlit's hamburger menu.
_ABOUT_TEXT = (
    "## Análisis con IA de lenguas indígenas en peligro\n"
    "Esta aplicación integra grafos de conocimiento de Glottolog, Wikipedia y Wikidata."
)
# Page title, favicon, wide layout and an initially open sidebar. This is the
# first st.* call in the script.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Atlas de Lenguas: Lenguas Indígenas Sudamericanas",
    page_icon="🌍",
    menu_items={"About": _ABOUT_TEXT},
)
# === CONFIGURATION ===
# Hugging Face Inference Endpoint that serves the answer-generating LLM.
ENDPOINT_URL = "https://lxgkooi70bj7diiu.us-east-1.aws.endpoints.huggingface.cloud"
# Bearer token read from the environment; None when the secret is not set.
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if HF_API_TOKEN:
    st.success("✅ Token cargado correctamente.")
else:
    st.error("⚠️ No se cargó el token HF_API_TOKEN desde los Secrets.")

# Sentence-embedding model used to encode user questions for retrieval.
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
# Run the embedder on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Namespace for language resource IRIs in the RDF graph.
EX = Namespace("http://example.org/lang/")
# === CUSTOM CSS ===
# Module-level call: injects the app's stylesheet every time this script runs.
# Classes such as .header, .response-card, .language-card, .metric-badge and
# .tech-badge are used by HTML snippets rendered with unsafe_allow_html=True
# elsewhere in this file. .feature-card, .sidebar-section, .sidebar-title and
# .suggested-question are not referenced by any code visible here.
st.markdown("""
<style>
.header {
color: #2c3e50;
border-bottom: 2px solid #4f46e5;
padding-bottom: 0.5rem;
margin-bottom: 1.5rem;
}
.feature-card {
background-color: #f8fafc;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border-left: 3px solid #4f46e5;
}
.response-card {
background-color: #fdfdfd;
color: #1f2937;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 2px 6px rgba(0,0,0,0.08);
margin: 1rem 0;
font-size: 1rem;
line-height: 1.5;
}
.language-card {
background-color: #f9fafb;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border: 1px solid #e5e7eb;
}
.sidebar-section {
margin-bottom: 1.5rem;
}
.sidebar-title {
font-weight: 600;
color: #4f46e5;
}
.suggested-question {
padding: 0.5rem;
margin: 0.25rem 0;
border-radius: 4px;
cursor: pointer;
transition: all 0.2s;
}
.suggested-question:hover {
background-color: #f1f5f9;
}
.metric-badge {
display: inline-block;
background-color: #e8f4fc;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.85rem;
margin-right: 0.5rem;
}
.tech-badge {
background-color: #ecfdf5;
color: #065f46;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
}
</style>
""", unsafe_allow_html=True)
# === CORE FUNCTIONS ===
@st.cache_resource(show_spinner="Cargando modelos de IA y grafos de conocimiento...")
def load_all_components():
    """Load the query embedder and the retrieval artifacts.

    Wrapped in ``st.cache_resource`` so the heavy objects are reused across
    reruns instead of being rebuilt on every interaction.

    Returns:
        ``(methods, embedder)``. ``methods`` maps the label ``"LinkGraph"`` to a
        ``(matrix, id_map, G, rdf)`` tuple. ``embedder`` is the
        SentenceTransformer used to encode queries.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)

    suffix = "_hybrid_graphsage"
    # NOTE: pickle.load can execute arbitrary code embedded in the file; these
    # artifacts must come from a trusted source.
    with open(f"id_map{suffix}.pkl", "rb") as handle:
        id_map = pickle.load(handle)
    with open(f"grafo_embed{suffix}.pickle", "rb") as handle:
        graph = pickle.load(handle)
    embeddings = np.load("embed_matrix_hybrid_graphsage.npy")

    rdf_graph = RDFGraph()
    rdf_graph.parse("grafo_ttl_hibrido_graphsage.ttl", format="ttl")

    return {"LinkGraph": (embeddings, id_map, graph, rdf_graph)}, embedder
def get_top_k(matrix, id_map, query, k, embedder):
    """Return the ids of the ``k`` rows of ``matrix`` most similar to ``query``.

    The query is encoded with the ``"query: "`` prefix and scored against every
    row of ``matrix`` by cosine similarity.

    Args:
        matrix: 2-D array of item embeddings, one row per item.
        id_map: list, or int-keyed dict, mapping row index -> item id.
        query: free-text user question.
        k: number of ids to return. Values <= 0 yield an empty list. Values
            larger than the number of rows are clamped to the row count.
        embedder: object whose ``encode`` returns a tensor exposing
            ``.cpu().numpy()``.

    Returns:
        List of ids ordered from most to least similar.
    """
    k = min(int(k), len(matrix))
    if k <= 0:
        # np.argsort(...)[-0:] is [0:], i.e. *every* row, so guard explicitly.
        return []
    vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
    vec = vec.cpu().numpy().astype("float32")
    # Cosine similarity; the epsilon avoids division by zero for all-zero vectors.
    sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
    top_k_idx = np.argsort(sims)[-k:][::-1]
    # Cast numpy integers to plain int so both list and dict id_maps work.
    return [id_map[int(i)] for i in top_k_idx]
def get_context(G, lang_id):
    """Build a Markdown summary of one language node.

    The first section is always the node's ``label`` (or ``lang_id`` when the
    node or label is missing). The Wikipedia summary, Wikidata description and
    countries follow, each only when present and non-empty. Sections are
    separated by blank lines.
    """
    attrs = G.nodes.get(lang_id, {})
    optional_fields = (
        ("wikipedia_summary", "Wikipedia"),
        ("wikidata_description", "Wikidata"),
        ("wikidata_countries", "Países"),
    )
    sections = [f"**Lengua:** {attrs.get('label', lang_id)}"]
    sections.extend(
        f"**{title}:** {attrs[attr]}"
        for attr, title in optional_fields
        if attrs.get(attr)
    )
    return "\n\n".join(sections)
def _local_name(uri):
    """Return the fragment or last path segment of an IRI (``...ns#type`` -> ``type``)."""
    text = str(uri)
    return text.rsplit("#", 1)[-1].rsplit("/", 1)[-1]


def query_rdf(rdf, lang_id):
    """List ``(predicate_local_name, value)`` pairs for the subject ``EX[lang_id]``.

    The subject IRI is supplied through ``initBindings`` rather than spliced into
    the SPARQL text as ``ex:<id>``. This way, ids containing characters that are
    illegal in a SPARQL prefixed name cannot break the query.

    Best-effort: any failure while querying yields a single
    ``("error", message)`` pair instead of raising.
    """
    subject = EX[str(lang_id)]
    q = "SELECT ?property ?value WHERE { ?lang ?property ?value }"
    try:
        # rdflib evaluates results lazily, so iteration stays inside the try.
        return [
            (_local_name(row[0]), str(row[1]))
            for row in rdf.query(q, initBindings={"lang": subject})
        ]
    except Exception as e:
        return [("error", str(e))]
def _extract_generated_text(payload):
    """Pull the generated text out of a Hugging Face endpoint JSON payload.

    Accepts ``[{"generated_text": ...}, ...]``, ``{"generated_text": ...}`` or
    ``{"text": ...}``; returns None for any other shape.
    """
    if isinstance(payload, list):
        first = payload[0] if payload else None
        if isinstance(first, dict) and "generated_text" in first:
            return first["generated_text"].strip()
        return None
    if isinstance(payload, dict):
        for key in ("generated_text", "text"):
            if key in payload:
                return payload[key].strip()
    return None


def query_llm(prompt):
    """Send ``prompt`` to the configured inference endpoint and return its text.

    Never raises. Transport, HTTP and parsing failures are reported as an
    ``"Error al consultar el modelo: ..."`` string so the UI can still render.
    An unrecognised payload shape yields ``"Sin respuesta del modelo."``.
    """
    if not HF_API_TOKEN:
        # Without a token the request would carry "Bearer None" and fail only
        # after a network round-trip (up to the 30 s timeout); fail fast instead.
        return "Error al consultar el modelo: HF_API_TOKEN no configurado."
    try:
        res = requests.post(
            ENDPOINT_URL,
            headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
            json={"inputs": prompt},
            timeout=30,
        )
        res.raise_for_status()
        text = _extract_generated_text(res.json())
    except Exception as e:
        # Deliberate best-effort boundary: surface the failure as text.
        return f"Error al consultar el modelo: {str(e)}"
    return text if text is not None else "Sin respuesta del modelo."
def generate_response(matrix, id_map, G, rdf, user_question, k, embedder):
    """Answer ``user_question`` with retrieval-augmented generation.

    Retrieves the top-``k`` language ids and builds a textual context and a list
    of RDF facts for them. It then queries the LLM twice (Spanish and English
    prompts) and formats both answers as an HTML snippet.

    Returns:
        ``(full_response_html, ids, context, rdf_facts)``. The HTML has the model
        output escaped, so it is safe to render with ``unsafe_allow_html=True``.
    """
    import html

    ids = get_top_k(matrix, id_map, user_question, k, embedder)
    context = [get_context(G, i) for i in ids]
    rdf_facts = [f"{p}: {v}" for i in ids for p, v in query_rdf(rdf, i)]

    joined_context = " ".join(context)
    joined_facts = ", ".join(rdf_facts)
    prompt_es = (
        f"Contexto: {joined_context}\n"
        f"Hechos RDF: {joined_facts}\n"
        f"Pregunta: {user_question} Responde en español:"
    )
    prompt_en = (
        f"Context: {joined_context}\n"
        f"RDF facts: {joined_facts}\n"
        f"Question: {user_question} Answer in English:"
    )
    # Model output is untrusted (it can echo the user's question or retrieved
    # web text) and is rendered as raw HTML by the caller, so escape it.
    response_es = html.escape(query_llm(prompt_es))
    response_en = html.escape(query_llm(prompt_en))
    full_response = (
        f"<b>Respuesta en español:</b><br>{response_es}<br><br>"
        f"<b>Answer in English:</b><br>{response_en}"
    )
    return full_response, ids, context, rdf_facts
# === UI MAIN ===
def _render_sidebar():
    """Draw the sidebar and return the user's analysis settings.

    Clicking a suggested question stores it in ``st.session_state.query`` so the
    main text input is pre-filled with it.

    Returns:
        ``(k, show_ctx, show_rdf)``: the number of languages to retrieve and the
        states of the two "advanced options" checkboxes.
    """
    with st.sidebar:
        st.markdown("### 📚 Información de Contacto")
        st.markdown("""
- <span class="tech-badge">Correo: jxvera@gmail.com</span>
""", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("### 🚀 Inicio Rápido")
        st.markdown("""
1. **Escribe una pregunta** en el cuadro de entrada
2. **Haz clic en 'Analizar'** para obtener la respuesta
3. **Explora los resultados** con los detalles expandibles
""")
        st.markdown("---")
        st.markdown("### 🔍 Preguntas de Ejemplo")
        questions = [
            "¿Qué idiomas están en peligro en Brasil?",
            "¿Qué idiomas se hablan en Perú?",
            "¿Cuáles idiomas están relacionados con el Quechua?",
            "¿Dónde se habla el Mapudungun?",
        ]
        for q in questions:
            if st.button(q, key=f"suggested_{q}", use_container_width=True):
                st.session_state.query = q
        st.markdown("---")
        st.markdown("### 📊 Parámetros de Análisis")
        k = st.slider("Número de idiomas a analizar", 1, 10, 3)
        st.markdown("---")
        st.markdown("### 🔧 Opciones Avanzadas")
        show_ctx = st.checkbox("Mostrar información de contexto", False)
        show_rdf = st.checkbox("Mostrar hechos estructurados", False)
    return k, show_ctx, show_rdf


def _render_results(methods, embedder, query, k, show_ctx, show_rdf):
    """Run the RAG pipeline for ``query`` and render the answer card plus optional details."""
    import html
    import time

    label = "LinkGraph"
    method = methods[label]
    # perf_counter is monotonic; wall-clock datetime.now() can jump mid-measurement.
    start = time.perf_counter()
    with st.spinner("Analizando..."):
        response, lang_ids, context, rdf_data = generate_response(*method, query, k, embedder)
    duration = time.perf_counter() - start
    st.markdown(f"""
<div class="response-card">
{response}
<div style="margin-top: 1rem;">
<span class="metric-badge">⏱️ {duration:.2f}s</span>
<span class="metric-badge">🌐 {len(lang_ids)} idiomas</span>
</div>
</div>
""", unsafe_allow_html=True)
    if show_ctx:
        with st.expander(f"📖 Contexto de {len(lang_ids)} idiomas"):
            for ctx in context:
                # ctx embeds external Wikipedia/Wikidata text; escape it before raw-HTML rendering.
                st.markdown(f"<div class='language-card'>{html.escape(ctx)}</div>", unsafe_allow_html=True)
    if show_rdf:
        with st.expander("🔗 Hechos estructurados (RDF)"):
            st.code("\n".join(rdf_data))


def _render_footer():
    """Draw the closing note shown below the main content."""
    st.markdown("---")
    st.markdown("""
<div style="font-size: 0.8rem; color: #64748b; text-align: center;">
<b>📌 Nota:</b> Esta herramienta está diseñada para investigadores, lingüistas y preservacionistas culturales.
Para mejores resultados, usa preguntas específicas sobre idiomas, familias o regiones.
</div>
""", unsafe_allow_html=True)


def main():
    """Render the page: header, sidebar controls, question input, results and footer."""
    methods, embedder = load_all_components()
    st.markdown("""
<div class="header">
<h1>🌍 Atlas de Lenguas: Lenguas Indígenas Sudamericanas</h1>
</div>
""", unsafe_allow_html=True)
    k, show_ctx, show_rdf = _render_sidebar()
    st.markdown("### 📝 Haz una pregunta sobre lenguas indígenas")
    query = st.text_input(
        "Ingresa tu pregunta:",
        value=st.session_state.get("query", ""),
        label_visibility="collapsed",
        placeholder="Ej. ¿Qué lenguas se hablan en Perú?",
    )
    if st.button("Analizar", type="primary", use_container_width=True):
        if query.strip():
            _render_results(methods, embedder, query, k, show_ctx, show_rdf)
        else:
            # Warn instead of returning early, so the footer below still renders.
            st.warning("Por favor, ingresa una pregunta")
    _render_footer()
# Script entry point: render the page. Presumably launched with `streamlit run`,
# which executes this file as __main__ — confirm for other launch modes.
if __name__ == "__main__":
    main()