ohalkhateeb committed on
Commit
c5ec562
·
verified ·
1 Parent(s): f179794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import gradio as gr
2
- from langchain.chains import RetrievalQA
3
  from langchain.llms import HuggingFaceHub # Import HuggingFaceHub for Jais
4
- from langchain.vectorstores import Chroma
5
  from langchain.embeddings import HuggingFaceEmbeddings
6
  import os
7
  import preprocess # Import the preprocess module
8
  import create_database # Import the create_database module
 
9
 
10
 
11
  # --- Preprocessing and Database Creation ---
@@ -22,9 +22,7 @@ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-Mi
22
  vector_db = Chroma(persist_directory="db", embedding_function=embedding_model)
23
  retriever = vector_db.as_retriever(search_kwargs={"k": 3})
24
 
25
-
26
  # Load Jais-13B from Hugging Face
27
-
28
  def initialize_llm():
29
  """Initializes the Hugging Face LLM using the HUGGINGFACEHUB_API_TOKEN environment variable."""
30
  huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
@@ -41,12 +39,15 @@ def initialize_llm():
41
 
42
 
43
  # Create the RetrievalQA chain
44
- qa_chain = RetrievalQA(llm=initialize_llm, retriever=retriever)
 
45
 
46
  # --- Gradio Interface ---
47
 
48
  def chatbot_interface(question):
49
- return qa_chain.run(question)
 
 
50
 
51
  iface = gr.Interface(
52
  fn=chatbot_interface,
 
1
  import gradio as gr
2
+ from langchain.chains import RetrievalQAWithSourcesChain # Import RetrievalQAWithSourcesChain instead of RetrievalQA
3
  from langchain.llms import HuggingFaceHub # Import HuggingFaceHub for Jais
 
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  import os
6
  import preprocess # Import the preprocess module
7
  import create_database # Import the create_database module
8
+ from langchain_chroma import Chroma # Import Chroma from langchain_chroma
9
 
10
 
11
  # --- Preprocessing and Database Creation ---
 
22
  vector_db = Chroma(persist_directory="db", embedding_function=embedding_model)
23
  retriever = vector_db.as_retriever(search_kwargs={"k": 3})
24
 
 
25
  # Load Jais-13B from Hugging Face
 
26
  def initialize_llm():
27
  """Initializes the Hugging Face LLM using the HUGGINGFACEHUB_API_TOKEN environment variable."""
28
  huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 
39
 
40
 
41
# Create the retrieval QA chain.
# chain_type="stuff" is required in LangChain 0.0.200 or later.
# BUG FIX: initialize_llm must be CALLED — `llm=initialize_llm` passed the
# function object itself, but `from_chain_type` expects an LLM instance
# (here, the HuggingFaceHub LLM that initialize_llm() returns).
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=initialize_llm(),
    chain_type="stuff",
    retriever=retriever,
)
44
 
45
  # --- Gradio Interface ---
46
 
47
def chatbot_interface(question):
    """Run *question* through the retrieval QA chain and return the answer text.

    The chain also produces a "sources" field; only the answer is surfaced
    to the Gradio UI.
    """
    response = qa_chain({"question": question})
    # Only return the answer, not the sources
    return response['answer']
51
 
52
  iface = gr.Interface(
53
  fn=chatbot_interface,