# NOTE: removed stray "Spaces: / Sleeping" page-scrape artifacts that preceded
# the module and broke parsing; they were not part of the source.
import numpy as np | |
from collections import defaultdict | |
from typing import List, Tuple, Callable, Dict | |
from aimakerspace.openai_utils.embedding import EmbeddingModel | |
import asyncio | |
import logging | |
import concurrent.futures | |
import time | |
# Configure logging
# Root-logger setup: INFO level, timestamped records tagged with the logger
# name so batch progress from abuild_from_list is easy to follow.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger named after this module, per the standard convention.
logger = logging.getLogger(__name__)
def cosine_similarity(vector_a: np.ndarray, vector_b: np.ndarray) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        vector_a: First vector.
        vector_b: Second vector, same length as ``vector_a``.

    Returns:
        Similarity in ``[-1.0, 1.0]``; ``0.0`` when either vector has zero
        norm (previously this divided by zero, yielding NaN and a
        RuntimeWarning, which then poisoned the sort in ``search``).
    """
    dot_product = np.dot(vector_a, vector_b)
    norm_product = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    # Guard the degenerate case: a zero-length vector has no direction, so
    # similarity is conventionally defined as 0 rather than NaN.
    if norm_product == 0.0:
        return 0.0
    return float(dot_product / norm_product)
class VectorDatabase:
    """In-memory vector store keyed by the original text of each chunk.

    Embeddings are produced by ``embedding_model``; ``abuild_from_list``
    embeds text in batches with a short pause between batches to stay
    within API rate limits.
    """

    def __init__(self, embedding_model: EmbeddingModel = None, batch_size: int = 25):
        """
        Args:
            embedding_model: Model used to embed text; a default
                ``EmbeddingModel`` is created when omitted.
            batch_size: Number of texts embedded per API call in
                ``abuild_from_list``.
        """
        # Plain dict: the previous defaultdict(np.array) was broken —
        # np.array() takes no zero-argument form, so the default factory
        # raised TypeError on any missing-key access instead of producing
        # a value. A missing key now raises the expected KeyError.
        self.vectors: Dict[str, np.ndarray] = {}
        self.embedding_model = embedding_model or EmbeddingModel()
        self.batch_size = batch_size  # texts per embedding request

    def insert(self, key: str, vector: np.ndarray) -> None:
        """Store ``vector`` under ``key``, overwriting any existing entry."""
        self.vectors[key] = vector

    def search(
        self,
        query_vector: np.ndarray,
        k: int,
        distance_measure: Callable = cosine_similarity,
    ) -> List[Tuple[str, float]]:
        """Return the ``k`` best-scoring ``(key, score)`` pairs.

        Higher scores rank first, so ``distance_measure`` must be a
        similarity (larger = closer), as cosine similarity is.
        """
        scores = [
            (key, distance_measure(query_vector, vector))
            for key, vector in self.vectors.items()
        ]
        return sorted(scores, key=lambda x: x[1], reverse=True)[:k]

    def search_by_text(
        self,
        query_text: str,
        k: int,
        distance_measure: Callable = cosine_similarity,
        return_as_text: bool = False,
    ) -> List[Tuple[str, float]]:
        """Embed ``query_text`` and search the store.

        Returns:
            ``(key, score)`` tuples by default; bare keys (the stored
            text chunks) when ``return_as_text`` is True.
        """
        query_vector = self.embedding_model.get_embedding(query_text)
        results = self.search(query_vector, k, distance_measure)
        return [result[0] for result in results] if return_as_text else results

    def retrieve_from_key(self, key: str) -> np.ndarray:
        """Return the stored vector for ``key``, or ``None`` if absent."""
        return self.vectors.get(key, None)

    async def abuild_from_list(self, list_of_text: List[str]) -> "VectorDatabase":
        """Embed ``list_of_text`` in batches and populate the store.

        Batches that fail to embed are logged and skipped (best-effort:
        a partial database is preferred over no database), so the final
        vector count may be smaller than ``len(list_of_text)``.

        Returns:
            ``self``, to allow ``db = await db.abuild_from_list(...)``.
        """
        start_time = time.time()
        if not list_of_text:
            logger.warning("Empty list provided to build vector database")
            return self
        logger.info(f"Building embeddings for {len(list_of_text)} text chunks in batches of {self.batch_size}")
        # Process in batches to avoid overwhelming the API
        batches = [list_of_text[i:i + self.batch_size] for i in range(0, len(list_of_text), self.batch_size)]
        logger.info(f"Split into {len(batches)} batches")
        for i, batch in enumerate(batches):
            batch_start = time.time()
            logger.info(f"Processing batch {i+1}/{len(batches)} with {len(batch)} text chunks")
            try:
                # Get embeddings for this batch
                embeddings = await self.embedding_model.async_get_embeddings(batch)
                # Insert into vector database, keyed by the source text itself
                for text, embedding in zip(batch, embeddings):
                    self.insert(text, np.array(embedding))
                batch_duration = time.time() - batch_start
                logger.info(f"Batch {i+1} completed in {batch_duration:.2f}s")
                # Small delay between batches to avoid rate limiting
                if i < len(batches) - 1:
                    await asyncio.sleep(0.5)
            except Exception as e:
                logger.error(f"Error processing batch {i+1}: {str(e)}")
                # Continue with next batch even if this one failed
        total_duration = time.time() - start_time
        logger.info(f"Vector database built with {len(self.vectors)} vectors in {total_duration:.2f}s")
        return self
if __name__ == "__main__":
    # Demo: build a tiny database from five sentences, then run a couple
    # of similarity queries against it.
    sample_texts = [
        "I like to eat broccoli and bananas.",
        "I ate a banana and spinach smoothie for breakfast.",
        "Chinchillas and kittens are cute.",
        "My sister adopted a kitten yesterday.",
        "Look at this cute hamster munching on a piece of broccoli.",
    ]
    db = asyncio.run(VectorDatabase().abuild_from_list(sample_texts))

    k = 2
    query = "I think fruit is awesome!"

    # Top-k (key, score) pairs for the query.
    print(f"Closest {k} vector(s):", db.search_by_text(query, k=k))

    # Direct lookup of one stored embedding by its source text.
    print("Retrieved vector:", db.retrieve_from_key(
        "I like to eat broccoli and bananas."
    ))

    # Same query, but return just the matching text chunks.
    print(f"Closest {k} text(s):", db.search_by_text(
        query, k=k, return_as_text=True
    ))