import os
import sys
import logging
import traceback
import asyncio
from typing import List

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Make sure the local aimakerspace package is importable
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Import from local aimakerspace module
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from openai import OpenAI

# Initialize the OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client ({'API key present' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'})")

class RetrievalAugmentedQAPipeline:
    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        """
        Run the RAG pipeline for the given user query.

        Returns a dict whose "response" value is an async stream of response chunks.
        """
        try:
            # 1. Retrieve relevant documents
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)

            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # search_by_text returns (text, score) tuples; join the texts into one context block
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])

                # Log similarity scores for debugging
                doc_scores = [f"{i+1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
                logger.info(f"Document similarity scores: {', '.join(doc_scores)}")

            # 2. Create the chat message payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
If the answer is not in the context, say that you don't know based on the available information.
Use the following document extracts to answer the user's question:

{documents_context}"""},
                {"role": "user", "content": user_query}
            ]

            # 3. Call the LLM and stream the output
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True
                    )
                    for chunk in stream:
                        # Some stream chunks (e.g. the final one) carry no content delta
                        if chunk.choices and chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"

            return {
                "response": generate_response()
            }
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())
            # Bind the message now; `e` is unbound once this except block exits
            error_message = f"I apologize, but an error occurred: {str(e)}"

            # Async generator so the caller can `async for` over the error too
            async def error_stream():
                yield error_message

            return {"response": error_stream()}
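
# Illustrative usage sketch (an assumption about the calling code, not part of
# the pipeline itself): `db` stands in for a VectorDatabase built elsewhere,
# e.g. via setup_vector_db() below.
async def _example_consume_stream(db: VectorDatabase, question: str) -> str:
    pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=db)
    result = await pipeline.arun_pipeline(question)
    # Accumulate streamed tokens into a single answer string
    answer = ""
    async for token in result["response"]:
        answer += token
    return answer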

def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks."""
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    try:
        # Pick a loader based on the file extension
        if file_name.lower().endswith('.txt'):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
            loader.load()
        elif file_name.lower().endswith('.pdf'):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
            loader.load()
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]

        # Get documents from the loader
        documents = loader.documents
        if not documents:
            logger.warning("No document content loaded")
            return ["No content found in the document"]
        logger.info(f"Loaded {len(documents)} document(s), {sum(len(d) for d in documents)} characters total")

        # Split text into chunks
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        text_chunks = text_splitter.split_texts(documents)
        logger.info(f"Split document into {len(text_chunks)} chunks")
        return text_chunks
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]

async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create a vector database from text chunks."""
    logger.info(f"Setting up vector database with {len(texts)} text chunks")
    embedding_model = EmbeddingModel()
    vector_db = VectorDatabase(embedding_model=embedding_model)
    try:
        await vector_db.abuild_from_list(texts)
        vector_db.documents = texts
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())
        # Fall back to a single placeholder entry; 1536 matches the
        # text-embedding-ada-002 embedding dimension
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db
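
# Minimal end-to-end sketch, assuming OPENAI_API_KEY is set; "sample.txt" and
# the query are placeholders rather than part of the assignment code.
async def _example_demo() -> None:
    chunks = process_file("sample.txt", "sample.txt")
    db = await setup_vector_db(chunks)
    pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=db)
    result = await pipeline.arun_pipeline("Summarize the document.")
    async for token in result["response"]:
        print(token, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(_example_demo())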