ohalkhateeb's picture
Update app.py
e11b217 verified
raw
history blame
1.89 kB
import gradio as gr
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub # Import HuggingFaceHub for Jais
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
import os
import preprocess # Import the preprocess module
import create_database # Import the create_database module
# --- Preprocessing and Database Creation ---
# First-run bootstrap: when no "db" path exists yet, run the preprocessing
# step and then build the vector database from its output file. Later starts
# find "db" already present and skip both steps.
if not os.path.exists("db"):
    # Input folder and output JSON path are handed to the project-local
    # preprocess module (presumably it reads ./documents and writes the JSON
    # file -- confirm in preprocess.py).
    preprocess.preprocess_and_save("./documents", "preprocessed_data.json")
    # The same JSON path goes to the database builder, together with the "db"
    # target that the existence check above looks for.
    create_database.create_vector_database("preprocessed_data.json", "db")
# --- RAG Pipeline ---
# Open the Chroma store persisted under "db", using the MiniLM sentence
# embedding model as its embedding function, and expose it as a retriever.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vector_db = Chroma(
    embedding_function=embedding_model,
    persist_directory="db",
)
# search_kwargs carries k=3, i.e. the retriever is configured for three
# results per query.
retriever = vector_db.as_retriever(search_kwargs={"k": 3})
# Load Jais-13B from Hugging Face
def initialize_llm():
    """Build the HuggingFaceHub client for the Jais chat model.

    The API token is read from the HUGGINGFACEHUB_API_TOKEN environment
    variable.

    Returns:
        A HuggingFaceHub LLM set up for the "text-generation" task with
        temperature 0.3 and max_new_tokens 512.

    Raises:
        ValueError: If the environment variable is unset or empty.
    """
    api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if api_token:
        generation_settings = {'temperature': 0.3, 'max_new_tokens': 512}
        return HuggingFaceHub(
            repo_id="jais-foundation/jais-13b-chat",
            huggingfacehub_api_token=api_token,
            task="text-generation",
            model_kwargs=generation_settings,
        )
    # Fail fast with a clear message instead of sending an unauthenticated
    # request later.
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
# Create the RetrievalQA chain
# initialize_llm() has to be called here: the name `llm` is not bound anywhere
# earlier in this module, so referencing it directly raises NameError.
llm = initialize_llm()
# Build the chain through the from_chain_type factory, which wraps the LLM in
# a document-combining chain; RetrievalQA's own constructor has no `llm`
# field and rejects it.
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
# --- Gradio Interface ---
def chatbot_interface(question):
    """Answer a user question with the retrieval QA chain.

    Args:
        question: Text submitted through the Gradio input box.

    Returns:
        The string produced by ``qa_chain.run`` for the question, or a short
        prompt asking for input when the question is missing or blank.
    """
    # Skip retrieval and the remote LLM call entirely for empty or
    # whitespace-only input; there is nothing meaningful to look up.
    if not question or not question.strip():
        return "Please enter a question."
    return qa_chain.run(question)
# One text input and one text output, both wired to chatbot_interface.
iface = gr.Interface(
    title="Dubai Legislation AI Chatbot",
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
)
# Launch the interface as soon as this module is executed.
iface.launch()