RAG-SA / rag_hf.py
javiervz's picture
Update rag_hf.py
87f954d verified
raw
history blame
13.2 kB
import streamlit as st
import datetime
import pickle
import numpy as np
import rdflib
import torch
import os
import requests
from rdflib import Graph as RDFGraph, Namespace
from sentence_transformers import SentenceTransformer
#from dotenv import load_dotenv
# === STREAMLIT UI CONFIG ===
st.set_page_config(
page_title="Atlas de Lenguas: Lenguas Indígenas Sudamericanas",
page_icon="🌍",
layout="wide",
initial_sidebar_state="expanded",
menu_items={
'About': "## Análisis con IA de lenguas indígenas en peligro\n"
"Esta aplicación integra grafos de conocimiento de Glottolog, Wikipedia y Wikidata."
}
)
# === CONFIGURATION ===
#load_dotenv()
ENDPOINT_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
#ENDPOINT_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if not HF_API_TOKEN:
st.error("⚠️ No se cargó el token HF_API_TOKEN desde los Secrets.")
else:
st.success("✅ Token cargado correctamente.")
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EX = Namespace("http://example.org/lang/")
# === CUSTOM CSS ===
st.markdown("""
<style>
.header {
color: #2c3e50;
border-bottom: 2px solid #4f46e5;
padding-bottom: 0.5rem;
margin-bottom: 1.5rem;
}
.feature-card {
background-color: #f8fafc;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border-left: 3px solid #4f46e5;
}
.response-card {
background-color: #fdfdfd;
color: #1f2937;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 2px 6px rgba(0,0,0,0.08);
margin: 1rem 0;
font-size: 1rem;
line-height: 1.5;
}
.language-card {
background-color: #f9fafb;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
border: 1px solid #e5e7eb;
}
.sidebar-section {
margin-bottom: 1.5rem;
}
.sidebar-title {
font-weight: 600;
color: #4f46e5;
}
.suggested-question {
padding: 0.5rem;
margin: 0.25rem 0;
border-radius: 4px;
cursor: pointer;
transition: all 0.2s;
}
.suggested-question:hover {
background-color: #f1f5f9;
}
.metric-badge {
display: inline-block;
background-color: #e8f4fc;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.85rem;
margin-right: 0.5rem;
}
.tech-badge {
background-color: #ecfdf5;
color: #065f46;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
}
</style>
""", unsafe_allow_html=True)
# === CORE FUNCTIONS ===
@st.cache_resource(show_spinner="Cargando modelos de IA y grafos de conocimiento...")
def load_all_components():
embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
methods = {}
# Solo carga el método LinkGraph
label, suffix, ttl, matrix_path = ("LinkGraph", "_hybrid_graphsage", "grafo_ttl_hibrido_graphsage.ttl", "embed_matrix_hybrid_graphsage.npy")
with open(f"id_map{suffix}.pkl", "rb") as f:
id_map = pickle.load(f)
with open(f"grafo_embed{suffix}.pickle", "rb") as f:
G = pickle.load(f)
matrix = np.load(matrix_path)
rdf = RDFGraph()
rdf.parse(ttl, format="ttl")
methods[label] = (matrix, id_map, G, rdf)
return methods, embedder
def get_top_k(matrix, id_map, query, k, embedder):
vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
vec = vec.cpu().numpy().astype("float32")
sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
top_k_idx = np.argsort(sims)[-k:][::-1]
return [id_map[i] for i in top_k_idx]
def get_context(G, lang_id):
node = G.nodes.get(lang_id, {})
lines = [f"**Lengua:** {node.get('label', lang_id)}"]
if node.get("wikipedia_summary"):
lines.append(f"**Wikipedia:** {node['wikipedia_summary']}")
if node.get("wikidata_description"):
lines.append(f"**Wikidata:** {node['wikidata_description']}")
if node.get("wikidata_countries"):
lines.append(f"**Países:** {node['wikidata_countries']}")
return "\n\n".join(lines)
def query_rdf(rdf, lang_id):
q = f"""
PREFIX ex: <http://example.org/lang/>
SELECT ?property ?value WHERE {{ ex:{lang_id} ?property ?value }}
"""
try:
return [(str(row[0]).split("/")[-1], str(row[1])) for row in rdf.query(q)]
except Exception as e:
return [("error", str(e))]
def generate_response(matrix, id_map, G, rdf, user_question, k, embedder):
ids = get_top_k(matrix, id_map, user_question, k, embedder)
context = [get_context(G, i) for i in ids]
rdf_facts = []
for i in ids:
rdf_facts.extend([f"{p}: {v}" for p, v in query_rdf(rdf, i)])
# Prompt para generar respuesta en español
prompt_es = f"""<s>[INST] Eres un experto en lenguas indígenas sudamericanas. Utiliza estricta y únicamente la información a continuación para responder la pregunta del usuario en **español**.
- No infieras ni asumas hechos que no estén explícitamente establecidos.
- Si la respuesta es desconocida o insuficiente, di "No puedo responder con los datos disponibles."
- Limita tu respuesta a 100 palabras.
### CONTEXTO: {chr(10).join(context)}
### RELACIONES RDF: {chr(10).join(rdf_facts)}
### PREGUNTA: {user_question}
Respuesta: [/INST]"""
# Prompt para generar respuesta en inglés
prompt_en = f"""<s>[INST] You are an expert in South American indigenous languages. Use strictly and only the information below to answer the user question in **English**.
- Do not infer or assume facts that are not explicitly stated.
- If the answer is unknown or insufficient, say \"I cannot answer with the available data.\"
- Limit your answer to 100 words.
### CONTEXT: {chr(10).join(context)}
### RDF RELATIONS: {chr(10).join(rdf_facts)}
### QUESTION: {user_question}
Answer: [/INST]"""
response_es = "Error al generar respuesta en español."
response_en = "Error generating response in English."
try:
# Generar respuesta en español
res_es = requests.post(
ENDPOINT_URL,
headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
json={"inputs": prompt_es}, timeout=60
)
out_es = res_es.json()
if isinstance(out_es, list) and "generated_text" in out_es[0]:
# Limpiar la respuesta para asegurar buen formato de markdown
response_es = out_es[0]["generated_text"].replace(prompt_es.strip(), "").strip()
response_es = response_es.replace('\n', ' ').replace(' ', ' ').strip() # Reemplazar saltos de línea con espacios, limpiar espacios dobles
# Generar respuesta en inglés
res_en = requests.post(
ENDPOINT_URL,
headers={"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"},
json={"inputs": prompt_en}, timeout=60
)
out_en = res_en.json()
if isinstance(out_en, list) and "generated_text" in out_en[0]:
# Limpiar la respuesta para asegurar buen formato de markdown
response_en = out_en[0]["generated_text"].replace(prompt_en.strip(), "").strip()
response_en = response_en.replace('\n', ' ').replace(' ', ' ').strip() # Reemplazar saltos de línea con espacios, limpiar espacios dobles
# Concatenar ambas respuestas con saltos de línea explícitos para el display
full_response = (
f"<b>Respuesta en español:</b><br>" # Usamos <br> para el salto de línea HTML
f"{response_es}<br><br>" # Dos <br> para un doble salto de línea
f"<b>Answer in English:</b><br>"
f"{response_en}"
)
return full_response, ids, context, rdf_facts
except Exception as e:
return f"Ocurrió un error al generar la respuesta: {str(e)}", ids, context, rdf_facts
# === MAIN APP ===
def main():
methods, embedder = load_all_components()
st.markdown("""
<div class="header">
<h1>🌍 Atlas de Lenguas: Lenguas Indígenas Sudamericanas</h1>
</div>
""", unsafe_allow_html=True)
with st.expander("📌 **Resumen General**", expanded=True):
st.markdown("""
Esta aplicación ofrece **análisis impulsado por IA, Grafos y RAGs (GraphRAGs)** de lenguas indígenas de América del Sur,
integrando información de **Glottolog, Wikipedia y Wikidata**.
""")
#st.markdown("*Puedes preguntar en **español o inglés**, y el modelo responderá en **ambos idiomas**.*")
with st.sidebar:
st.markdown("### 📚 Información de Contacto")
st.markdown("""
- <span class="tech-badge">Correo: jxvera@gmail.com</span>
""", unsafe_allow_html=True)
st.markdown("---")
st.markdown("### 🚀 Inicio Rápido")
st.markdown("""
1. **Escribe una pregunta** en el cuadro de entrada
2. **Haz clic en 'Analizar'** para obtener la respuesta
3. **Explora los resultados** con los detalles expandibles
""")
st.markdown("---")
st.markdown("### 🔍 Preguntas de Ejemplo")
questions = [
"¿Qué idiomas están en peligro en Brasil? (What languages are endangered in Brazil?)",
"¿Qué idiomas se hablan en Perú? (What languages are spoken in Perú?)",
"¿Cuáles idiomas están relacionados con el Quechua? (Which languages are related to Quechua?)",
"¿Dónde se habla el Mapudungun? (Where is Mapudungun spoken?)"
]
for q in questions:
if st.button(q, key=f"suggested_{q}", use_container_width=True):
st.session_state.query = q.split(" (")[0]
st.markdown("---")
st.markdown("### ⚙️ Detalles Técnicos")
st.markdown("""
- <span class="tech-badge">Embeddings</span> GraphSAGE
- <span class="tech-badge">Modelo de Lenguaje</span> Mistral-7B-Instruct
- <span class="tech-badge">Grafo de Conocimiento</span> Integración basada en RDF
""", unsafe_allow_html=True)
st.markdown("---")
st.markdown("### 📂 Fuentes de Datos")
st.markdown("""
- **Glottolog** (Clasificación de idiomas)
- **Wikipedia** (Resúmenes textuales)
- **Wikidata** (Hechos estructurados)
""")
st.markdown("---")
st.markdown("### 📊 Parámetros de Análisis")
k = st.slider("Número de idiomas a analizar", 1, 10, 3)
st.markdown("---")
st.markdown("### 🔧 Opciones Avanzadas")
show_ctx = st.checkbox("Mostrar información de contexto", False)
show_rdf = st.checkbox("Mostrar hechos estructurados", False)
st.markdown("### 📝 Haz una pregunta sobre lenguas indígenas")
st.markdown("*(Puedes preguntar en español o inglés, y el modelo responderá en **ambos idiomas**.)*")
query = st.text_input(
"Ingresa tu pregunta:",
value=st.session_state.get("query", ""),
label_visibility="collapsed",
placeholder="Ej. ¿Qué lenguas se hablan en Perú?"
)
if st.button("Analizar", type="primary", use_container_width=True):
if not query:
st.warning("Por favor, ingresa una pregunta")
return
label = "LinkGraph"
method = methods[label]
#st.markdown(f"#### Método {label}")
#st.caption("Embeddings de GraphSAGE que capturan patrones en el grafo de conocimiento")
start = datetime.datetime.now()
response, lang_ids, context, rdf_data = generate_response(*method, query, k, embedder)
duration = (datetime.datetime.now() - start).total_seconds()
st.markdown(f"""
<div class="response-card">
{response}
<div style="margin-top: 1rem;">
<span class="metric-badge">⏱️ {duration:.2f}s</span>
<span class="metric-badge">🌐 {len(lang_ids)} idiomas</span>
</div>
</div>
""", unsafe_allow_html=True)
if show_ctx:
with st.expander(f"📖 Contexto de {len(lang_ids)} idiomas"):
for lang_id, ctx in zip(lang_ids, context):
st.markdown(f"<div class='language-card'>{ctx}</div>", unsafe_allow_html=True)
if show_rdf:
with st.expander("🔗 Hechos estructurados (RDF)"):
st.code("\n".join(rdf_data))
st.markdown("---")
st.markdown("""
<div style="font-size: 0.8rem; color: #64748b; text-align: center;">
<b>📌 Nota:</b> Esta herramienta está diseñada para investigadores, lingüistas y preservacionistas culturales.
Para mejores resultados, usa preguntas específicas sobre idiomas, familias o regiones.
</div>
""", unsafe_allow_html=True)
if __name__ == "__main__":
main()