ohalkhateeb's picture
Update app.py
c5ec562 verified
raw
history blame
2.21 kB
import gradio as gr
from langchain.chains import RetrievalQAWithSourcesChain # Import RetrievalQAWithSourcesChain instead of RetrievalQA
from langchain.llms import HuggingFaceHub # Import HuggingFaceHub for Jais
from langchain.embeddings import HuggingFaceEmbeddings
import os
import preprocess # Import the preprocess module
import create_database # Import the create_database module
from langchain_chroma import Chroma # Import Chroma from langchain_chroma
# --- Preprocessing and Database Creation ---
# First-run bootstrap: when nothing exists at "db" yet, preprocess the raw
# documents and then build the vector database from the preprocessed output.
# NOTE(review): both calls are assumed to belong under this guard -- confirm.
if not os.path.exists("db"):
    preprocess.preprocess_and_save(
        "./documents",
        "preprocessed_data.json",
    )
    create_database.create_vector_database(
        "preprocessed_data.json",
        "db",
    )
# --- RAG Pipeline ---
# Open the Chroma store persisted under "db" with a sentence-transformers
# embedding model, then expose it as a retriever.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vector_db = Chroma(
    persist_directory="db",
    embedding_function=embedding_model,
)
# k=3 is passed as a search kwarg (presumably the number of documents
# returned per query -- confirm against the retriever documentation).
retriever = vector_db.as_retriever(
    search_kwargs={"k": 3},
)
# Load Jais-13B from Hugging Face
def initialize_llm():
    """Build the Hugging Face Hub LLM client for the Jais chat model.

    The API token is read from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable.

    Returns:
        A ``HuggingFaceHub`` instance for ``jais-foundation/jais-13b-chat``
        configured with the ``text-generation`` task.

    Raises:
        ValueError: If the environment variable is unset or empty.
    """
    api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not api_token:
        # Fail early with a clear message instead of sending an
        # unauthenticated request later.
        raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

    generation_settings = {"temperature": 0.3, "max_new_tokens": 512}
    return HuggingFaceHub(
        repo_id="jais-foundation/jais-13b-chat",
        huggingfacehub_api_token=api_token,
        task="text-generation",
        model_kwargs=generation_settings,
    )
# Create the RetrievalQAWithSourcesChain.
# chain_type="stuff" is required in LangChain 0.0.200 or later
# from_chain_type needs an LLM *instance*, so initialize_llm must be called
# here; passing the bare function object hands the chain a plain callable
# and skips the API-token check entirely.
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=initialize_llm(),
    chain_type="stuff",
    retriever=retriever,
)
# --- Gradio Interface ---
def chatbot_interface(question):
    """Answer a user question through the module-level QA chain.

    Args:
        question: The question text, passed to the chain under the
            ``"question"`` key.

    Returns:
        The ``"answer"`` entry of the chain output; any other entries in the
        output (such as sources) are intentionally not returned.
    """
    chain_output = qa_chain({"question": question})
    answer = chain_output['answer']
    return answer
# Text-in / text-out Gradio UI wired to chatbot_interface; launched as soon
# as the module is executed.
iface = gr.Interface(
    title="Dubai Legislation AI Chatbot",
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
)
iface.launch()