# RAG-SA / rag_hf.py
# Commit 379044a (verified) by javiervz: "Update rag_hf.py" — 6.3 kB
import datetime
import html
import os
import pickle
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import rdflib
import requests
import streamlit as st
import torch
from rdflib import Graph as RDFGraph, Namespace
from sentence_transformers import SentenceTransformer
# === STREAMLIT UI CONFIG ===
# Page metadata: browser-tab title/icon, wide layout, open sidebar, and the
# text shown under the hamburger menu's "About" entry.
_ABOUT_TEXT = (
    "## Análisis con IA de lenguas indígenas en peligro\n"
    "Esta aplicación integra grafos de conocimiento de Glottolog, Wikipedia y Wikidata."
)
st.set_page_config(
    page_title="Atlas de Lenguas: Lenguas Indígenas Sudamericanas",
    page_icon="🌍",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={"About": _ABOUT_TEXT},
)
# === CONFIGURATION ===
# Hosted text-generation model queried by query_llm().
ENDPOINT_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"

# API token read from the environment; its presence is reported in the UI.
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if HF_API_TOKEN:
    st.success("✅ Token cargado correctamente.")
else:
    st.error("⚠️ No se cargó el token HF_API_TOKEN desde los Secrets.")

# Sentence-embedding model used to encode user questions.
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"

# Run the embedder on the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Namespace of the language resources in the RDF graph.
EX = Namespace("http://example.org/lang/")
# === CUSTOM CSS ===
# Style for the small "tech-badge" pill used by the contact line in main().
_CUSTOM_CSS = """
<style>
.tech-badge {
background-color: #ecfdf5;
color: #065f46;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# === CORE FUNCTIONS ===
@st.cache_resource(show_spinner="Cargando modelos de IA y grafos de conocimiento...")
def load_all_components():
    """Load the sentence embedder and the hybrid GraphSAGE retrieval artifacts.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded objects
    instead of reloading them on every rerun.

    Returns:
        ``(methods, embedder)`` where ``methods`` maps the single key
        ``"LinkGraph"`` to the tuple ``(matrix, id_map, G, rdf)``:

        * ``matrix`` – array loaded from ``embed_matrix_hybrid_graphsage.npy``;
        * ``id_map`` – unpickled row-index -> language-id lookup;
        * ``G`` – unpickled graph whose ``.nodes`` carry language attributes
          (presumably a networkx graph — confirm);
        * ``rdf`` – rdflib graph parsed from the Turtle file.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)

    suffix = "_hybrid_graphsage"
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # only ever come from a trusted source (the app's own artifacts).
    with open(f"id_map{suffix}.pkl", "rb") as handle:
        id_map = pickle.load(handle)
    with open(f"grafo_embed{suffix}.pickle", "rb") as handle:
        graph = pickle.load(handle)

    matrix = np.load("embed_matrix_hybrid_graphsage.npy")

    rdf = RDFGraph()
    rdf.parse("grafo_ttl_hibrido_graphsage.ttl", format="ttl")

    return {"LinkGraph": (matrix, id_map, graph, rdf)}, embedder
def get_top_k(matrix, id_map, query, k, embedder):
    """Return the ids of the ``k`` rows of ``matrix`` most cosine-similar to ``query``.

    Args:
        matrix: 2-D array whose row ``i`` is the embedding of ``id_map[i]``
            (assumed to be in the same vector space as ``embedder``'s output —
            TODO confirm against the artifact builder).
        id_map: Indexable lookup from row index to language id.
        query: Free-text user question, encoded with the ``"query: "`` prefix.
        k: Number of ids to return; ``k <= 0`` yields an empty list.
        embedder: Object exposing ``encode`` (a SentenceTransformer here).

    Returns:
        Language ids ordered from most to least similar.
    """
    # Guard: the slice ``[-k:]`` below returns *every* row for k == 0 and
    # almost every row for negative k, instead of "no results".
    if k <= 0:
        return []
    vec = embedder.encode(f"query: {query}", convert_to_tensor=True, device=DEVICE)
    vec = vec.cpu().numpy().astype("float32")
    # Cosine similarity of every row against the query; the epsilon avoids
    # division by zero for all-zero vectors.
    sims = np.dot(matrix, vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec) + 1e-10)
    top_k_idx = np.argsort(sims)[-k:][::-1]
    return [id_map[i] for i in top_k_idx]
def get_context(G, lang_id):
    """Build a Markdown summary of one language node of graph ``G``.

    The first line is always the node's ``label`` (falling back to
    ``lang_id``).  Wikipedia summary, Wikidata description and countries are
    appended only when the node has a truthy value for them.  Sections are
    separated by blank lines.
    """
    attrs = G.nodes.get(lang_id, {})
    sections = [f"**Lengua:** {attrs.get('label', lang_id)}"]
    optional_fields = (
        ("wikipedia_summary", "**Wikipedia:**"),
        ("wikidata_description", "**Wikidata:**"),
        ("wikidata_countries", "**Países:**"),
    )
    for key, heading in optional_fields:
        value = attrs.get(key)
        if value:
            sections.append(f"{heading} {value}")
    return "\n\n".join(sections)
def query_rdf(rdf, lang_id):
    """List the outgoing RDF edges of ``ex:<lang_id>`` as ``(property, value)`` strings.

    ``property`` is the last ``/``-separated segment of the predicate IRI and
    ``value`` is the object converted with ``str``.  Best-effort: any failure
    is reported as a single ``("error", message)`` pair instead of raising.
    """
    try:
        # Look the subject up directly instead of string-formatting a SPARQL
        # query: ids with characters illegal in a prefixed name (spaces,
        # parentheses, '/', ...) broke the interpolated ``ex:{lang_id}``
        # query.  Skipping the SPARQL parser on every call is also cheaper.
        # EX is the same namespace the old query declared as ``ex:``.
        subject = EX[str(lang_id)]
        return [
            (str(prop).split("/")[-1], str(value))
            for prop, value in rdf.predicate_objects(subject)
        ]
    except Exception as e:
        return [("error", str(e))]
def query_llm(prompt):
    """Send ``prompt`` to the Hugging Face text-generation endpoint.

    Returns:
        The generated continuation, stripped.  ``"Sin respuesta del modelo."``
        is returned when the payload has no usable ``generated_text``.
        Network, HTTP, JSON and API-reported errors are returned as an
        ``"Error al consultar el modelo: ..."`` string; the function does not
        raise, so the UI can show the message inline.
    """
    headers = {"Content-Type": "application/json"}
    if HF_API_TOKEN:
        # Omit the header when no token is configured instead of sending
        # the literal "Bearer None".
        headers["Authorization"] = f"Bearer {HF_API_TOKEN}"
    payload = {
        "inputs": prompt,
        # The text-generation task echoes the prompt by default, which would
        # dump the whole RAG context into the answer shown to the user.
        "parameters": {"return_full_text": False},
    }
    try:
        res = requests.post(ENDPOINT_URL, headers=headers, json=payload, timeout=60)
        res.raise_for_status()
        out = res.json()
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"Error al consultar el modelo: {e}"

    if isinstance(out, list) and out and isinstance(out[0], dict):
        out = out[0]
    if not isinstance(out, dict):
        return "Sin respuesta del modelo."
    if "error" in out and "generated_text" not in out:
        return f"Error al consultar el modelo: {out['error']}"
    text = out.get("generated_text")
    if not isinstance(text, str):
        return "Sin respuesta del modelo."
    # Defensive: strip an echoed prompt if the backend ignored return_full_text.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()
def generate_response(matrix, id_map, G, rdf, user_question, k, embedder):
    """Answer ``user_question`` in Spanish and English using graph-based RAG.

    Retrieves the ``k`` most similar languages, gathers their graph summaries
    and RDF facts, and asks the LLM for a short answer in each language.

    Returns:
        ``(full_response, ids, context, rdf_facts)``:

        * ``full_response`` – HTML snippet with both answers (model output
          HTML-escaped);
        * ``ids`` – retrieved language ids;
        * ``context`` – one Markdown summary per id;
        * ``rdf_facts`` – ``"property: value"`` lines for all retrieved ids.
    """
    ids = get_top_k(matrix, id_map, user_question, k, embedder)
    context = [get_context(G, i) for i in ids]
    rdf_facts = [f"{p}: {v}" for i in ids for p, v in query_rdf(rdf, i)]

    context_block = "\n".join(context)
    facts_block = "\n".join(rdf_facts)
    prompt_es = (
        "Eres un experto en lenguas indígenas sudamericanas.\n"
        "Usa solo la información del contexto y hechos RDF siguientes.\n\n"
        f"### CONTEXTO:\n{context_block}\n\n"
        f"### RELACIONES RDF:\n{facts_block}\n\n"
        f"### PREGUNTA:\n{user_question}\n\nRespuesta breve en español:"
    )
    prompt_en = (
        "You are an expert in South American indigenous languages.\n"
        "Use only the following context and RDF facts to answer.\n\n"
        f"### CONTEXT:\n{context_block}\n\n"
        f"### RDF RELATIONS:\n{facts_block}\n\n"
        f"### QUESTION:\n{user_question}\n\nShort answer in English:"
    )

    # The two completions are independent network calls; run them
    # concurrently so the user waits for the slower one, not their sum.
    with ThreadPoolExecutor(max_workers=2) as pool:
        future_es = pool.submit(query_llm, prompt_es)
        future_en = pool.submit(query_llm, prompt_en)
        response_es = future_es.result()
        response_en = future_en.result()

    # main() renders this with unsafe_allow_html=True, so escape the model
    # output: only our own <b>/<br> tags should be interpreted as HTML.
    full_response = (
        f"<b>Respuesta en español:</b><br>{html.escape(response_es, quote=False)}<br><br>"
        f"<b>Answer in English:</b><br>{html.escape(response_en, quote=False)}"
    )
    return full_response, ids, context, rdf_facts
def main():
    """Render the Streamlit page and answer the user's question on demand."""
    methods, embedder = load_all_components()
    st.title("Atlas de Lenguas: Lenguas Indígenas Sudamericanas")
    st.markdown("<span class='tech-badge'>Correo: jxvera@gmail.com</span>", unsafe_allow_html=True)
    query = st.text_input("Escribe tu pregunta sobre lenguas indígenas:")
    k = st.slider("Número de lenguas similares a recuperar", min_value=1, max_value=10, value=3)

    if not st.button("Analizar"):
        return
    if not query.strip():
        # Don't spend an embedding pass and two LLM calls on an empty question.
        st.warning("Escribe una pregunta antes de analizar.")
        return

    method = methods["LinkGraph"]
    # perf_counter is monotonic, unlike wall-clock datetime.now().
    start = time.perf_counter()
    with st.spinner("Consultando grafos y modelo de lenguaje..."):
        response, lang_ids, context, rdf_data = generate_response(*method, query, k, embedder)
    duration = time.perf_counter() - start

    st.markdown(response, unsafe_allow_html=True)
    st.caption(f"⏱️ {duration:.2f} segundos | 🌐 {len(lang_ids)} idiomas analizados")
    with st.expander("📖 Contexto"):
        for ctx in context:
            st.markdown(ctx)
    with st.expander("🔗 Hechos RDF"):
        st.code("\n".join(rdf_data))


if __name__ == "__main__":
    main()