leo-bourrel committed on
Commit
503aad2
·
1 Parent(s): c3b9b9a

feat: Override PGVector with custom source

Browse files
Files changed (1) hide show
  1. custom_pgvector.py +789 -0
custom_pgvector.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import enum
6
+ import logging
7
+ from functools import partial
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ Generator,
14
+ Iterable,
15
+ List,
16
+ Optional,
17
+ Tuple,
18
+ Type,
19
+ )
20
+
21
+ import numpy as np
22
+ import sqlalchemy
23
+ from langchain.docstore.document import Document
24
+ from langchain.schema.embeddings import Embeddings
25
+ from langchain.utils import get_from_dict_or_env
26
+ from langchain.vectorstores.base import VectorStore
27
+ from langchain.vectorstores.pgvector import BaseModel
28
+ from langchain.vectorstores.utils import maximal_marginal_relevance
29
+ from pgvector.sqlalchemy import Vector
30
+ from sqlalchemy import delete
31
+ from sqlalchemy.orm import Session, declarative_base, relationship
32
+
33
+ if TYPE_CHECKING:
34
+ from langchain.vectorstores._pgvector_data_models import CollectionStore
35
+
36
+
37
class DistanceStrategy(str, enum.Enum):
    """Distance metrics that can be selected for ranking stored vectors.

    Every member is also a ``str``, so it compares equal to its plain
    string value (for example ``DistanceStrategy.COSINE == "cosine"``).
    """

    # Euclidean (L2) distance.
    EUCLIDEAN = "l2"
    # Cosine distance.
    COSINE = "cosine"
    # Inner-product based ranking.
    MAX_INNER_PRODUCT = "inner"
43
+
44
+
45
# Metric used whenever a caller does not choose one explicitly.
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE

# Declarative base that the ORM models in this module inherit from.
Base = declarative_base()  # type: Any


# Default value of ``table_name`` for several of the factory classmethods.
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
51
+
52
+
53
def _results_to_docs(docs_and_scores: Any) -> List[Document]:
    """Drop the scores from ``(document, score)`` pairs, preserving order."""
    documents = []
    for document, _score in docs_and_scores:
        documents.append(document)
    return documents
56
+
57
+
58
class Article(Base):
    """ORM mapping of the ``article`` table (text fields plus one vector column)."""

    __tablename__ = "article"

    # Integer primary key.
    id = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, primary_key=True)
    # Optional text fields.
    title = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    abstract = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    # Python attribute ``embedding`` is stored in the database column named
    # ``abstract_embedding``; ``Vector(None)`` passes no dimension.
    embedding: Vector = sqlalchemy.Column("abstract_embedding", Vector(None))
    doi = sqlalchemy.Column(sqlalchemy.String, nullable=True)
68
+
69
+
70
+ class CustomPGVector(VectorStore):
71
+ """`Postgres`/`PGVector` vector store.
72
+
73
+ To use, you should have the ``pgvector`` python package installed.
74
+
75
+ Args:
76
+ connection_string: Postgres connection string.
77
+ embedding_function: Any embedding function implementing
78
+ `langchain.embeddings.base.Embeddings` interface.
79
+ table_name: The name of the collection to use. (default: langchain)
80
+ NOTE: This is not the name of the table, but the name of the collection.
81
+ The tables will be created when initializing the store (if not exists)
82
+ So, make sure the user has the right permissions to create tables.
83
+ distance_strategy: The distance strategy to use. (default: COSINE)
84
+ pre_delete_collection: If True, will delete the collection if it exists.
85
+ (default: False). Useful for testing.
86
+
87
+ Example:
88
+ .. code-block:: python
89
+
90
+ from langchain.vectorstores import PGVector
91
+ from langchain.embeddings.openai import OpenAIEmbeddings
92
+
93
+ CONNECTION_STRING = "postgresql+psycopg2://hwc@localhost:5432/test3"
94
+ COLLECTION_NAME = "state_of_the_union_test"
95
+ embeddings = OpenAIEmbeddings()
96
+ vectorestore = PGVector.from_documents(
97
+ embedding=embeddings,
98
+ documents=docs,
99
+ table_name=COLLECTION_NAME,
100
+ connection_string=CONNECTION_STRING,
101
+ )
102
+
103
+
104
+ """
105
+
106
def __init__(
    self,
    connection_string: str,
    embedding_function: Embeddings,
    table_name: str,
    column_name: str,
    collection_metadata: Optional[dict] = None,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    pre_delete_collection: bool = False,
    logger: Optional[logging.Logger] = None,
    relevance_score_fn: Optional[Callable[[float], float]] = None,
) -> None:
    """Record the configuration, then run ``__post_init__``.

    Args:
        connection_string: Postgres connection string.
        embedding_function: Embedding model used for queries and inserts.
        table_name: Stored on the instance as ``self.table_name``.
        column_name: Stored on the instance as ``self.column_name``.
        collection_metadata: Optional metadata kept on the instance.
        distance_strategy: Metric used to rank results.
        pre_delete_collection: Stored on the instance.
        logger: Logger to use; a module-level logger is used when omitted.
        relevance_score_fn: Optional override for relevance scoring.
    """
    # Where to connect and how to embed.
    self.connection_string = connection_string
    self.embedding_function = embedding_function

    # Storage target.
    self.table_name = table_name
    self.column_name = column_name
    self.collection_metadata = collection_metadata

    # Search behaviour.
    self._distance_strategy = distance_strategy
    self.pre_delete_collection = pre_delete_collection
    self.logger = logger if logger else logging.getLogger(__name__)
    self.override_relevance_score_fn = relevance_score_fn

    self.__post_init__()
128
+
129
def __post_init__(self) -> None:
    """Open the connection, request the ``vector`` extension and bind the model."""
    connection = self.connect()
    self._conn = connection
    self.create_vector_extension()

    # All queries in this store go through the ``Article`` model.
    self.EmbeddingStore = Article
139
+
140
@property
def embeddings(self) -> Embeddings:
    """Embedding model this store was constructed with."""
    model = self.embedding_function
    return model
143
+
144
def connect(self) -> sqlalchemy.engine.Connection:
    """Create an engine for ``self.connection_string`` and open a connection."""
    return sqlalchemy.create_engine(self.connection_string).connect()
148
+
149
def create_vector_extension(self) -> None:
    """Run ``CREATE EXTENSION IF NOT EXISTS vector`` on the connection.

    Any exception is logged with its traceback and not re-raised.
    """
    try:
        ddl = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
        with Session(self._conn) as db:
            db.execute(ddl)
            db.commit()
    except Exception as error:
        self.logger.exception(error)
157
+
158
def drop_tables(self) -> None:
    """Drop every table registered on ``Base.metadata`` inside one transaction."""
    connection = self._conn
    with connection.begin():
        Base.metadata.drop_all(connection)
161
+
162
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
    """Yield a session bound to ``self._conn`` and close it on exit.

    The session is opened with its own context manager so that it is closed
    when the caller's ``with`` block ends, including when it raises.
    """
    with Session(self._conn) as session:
        yield session
166
+
167
def delete(
    self,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Delete stored rows by id.

    The filter uses the model's ``custom_id`` column when the model defines
    one and otherwise falls back to its ``id`` column, because the
    ``Article`` model bound in ``__post_init__`` has no ``custom_id``.

    Args:
        ids: Ids of the rows to delete. Nothing happens when this is None.
        kwargs: Accepted for interface compatibility; not used.
    """
    if ids is None:
        return

    model = self.EmbeddingStore
    # Compare with ``is None``; SQLAlchemy column attributes must not be
    # evaluated for truthiness.
    id_column = getattr(model, "custom_id", None)
    if id_column is None:
        id_column = model.id

    with Session(self._conn) as session:
        self.logger.debug(
            "Trying to delete vectors by ids (represented by the model "
            "using the custom ids field)"
        )
        # ``sqlalchemy.delete`` is spelled out so the call cannot be confused
        # with this method's own name.
        stmt = sqlalchemy.delete(model).where(id_column.in_(ids))
        session.execute(stmt)
        session.commit()
188
+
189
@classmethod
def __from(
    cls,
    texts: List[str],
    embeddings: List[List[float]],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    table_name: str = "article",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    connection_string: Optional[str] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomPGVector:
    """Build a store and insert the given texts with their embeddings.

    Args:
        texts: Texts to store.
        embeddings: One precomputed vector per text.
        embedding: Embedding model kept on the store for later queries.
        metadatas: Optional metadata dicts, one per text.
        ids: Optional ids, one per text.
        table_name: Forwarded to the constructor.
        distance_strategy: Forwarded to the constructor.
        connection_string: Postgres connection string; resolved through
            ``get_connection_string`` when omitted.
        pre_delete_collection: Forwarded to the constructor.
        kwargs: Extra constructor arguments (for example ``column_name``);
            also forwarded to ``add_embeddings``.

    Returns:
        The populated store.
    """
    if not metadatas:
        metadatas = [{} for _ in texts]
    if connection_string is None:
        connection_string = cls.get_connection_string(kwargs)

    # ``column_name`` is a required constructor argument; supply a default so
    # callers that do not pass it no longer hit a TypeError. The default is
    # the database column name used by the ``Article`` model.
    init_kwargs = {"column_name": "abstract_embedding", **kwargs}

    store = cls(
        connection_string=connection_string,
        table_name=table_name,
        embedding_function=embedding,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **init_kwargs,
    )

    store.add_embeddings(
        texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
    )

    return store
222
+
223
def add_embeddings(
    self,
    texts: Iterable[str],
    embeddings: List[List[float]],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Insert texts and their precomputed embeddings as rows of the model.

    Rows are built from the columns the ``Article`` model actually defines:
    the text goes to ``abstract``, the vector to ``embedding``, and the
    ``title`` / ``doi`` metadata keys (when present) to the columns of the
    same name.

    Args:
        texts: Iterable of strings to add to the vectorstore.
        embeddings: List of embedding vectors, one per text.
        metadatas: Optional list of metadata dicts, one per text.
        ids: Optional primary-key values, one per text. When omitted the
            database assigns them.
        kwargs: vectorstore specific parameters (unused).

    Returns:
        The primary keys of the inserted rows, as strings.
    """
    # ``texts`` may be a one-shot iterator; materialise it before it is
    # walked more than once.
    texts = list(texts)
    if not metadatas:
        metadatas = [{} for _ in texts]
    # ``zip`` cannot take ``None``; use placeholders when no ids are given.
    row_ids = list(ids) if ids is not None else [None] * len(texts)

    with Session(self._conn) as session:
        rows = []
        for text, metadata, embedding, row_id in zip(
            texts, metadatas, embeddings, row_ids
        ):
            metadata = metadata or {}
            values = {
                "abstract": text,
                "embedding": embedding,
                "title": metadata.get("title"),
                "doi": metadata.get("doi"),
            }
            if row_id is not None:
                values["id"] = row_id
            row = self.EmbeddingStore(**values)
            session.add(row)
            rows.append(row)
        # Flush so generated primary keys are populated before they are read.
        session.flush()
        stored_ids = [str(row.id) for row in rows]
        session.commit()

    return stored_ids
257
+
258
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Run more texts through the embeddings and add to the vectorstore.

    Args:
        texts: Iterable of strings to add to the vectorstore.
        metadatas: Optional list of metadatas associated with the texts.
        ids: Optional list of ids associated with the texts.
        kwargs: vectorstore specific parameters

    Returns:
        List of ids from adding the texts into the vectorstore.
    """
    # Materialise once: a generator would otherwise be exhausted by the
    # embedding call and reach ``add_embeddings`` empty.
    texts = list(texts)
    embeddings = self.embedding_function.embed_documents(texts)
    return self.add_embeddings(
        texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
279
+
280
def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed ``query`` and return the documents closest to it.

    Args:
        query (str): Query text to search for.
        k (int): Number of results to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Passed through to the vector
            search. Defaults to None.

    Returns:
        List of Documents most similar to the query.
    """
    query_vector = self.embedding_function.embed_query(text=query)
    search_args = {"embedding": query_vector, "k": k, "filter": filter}
    return self.similarity_search_by_vector(**search_args)
303
+
304
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Embed ``query`` and return the closest documents with their scores.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Passed through to the vector
            search. Defaults to None.

    Returns:
        List of Documents most similar to the query and score for each.
    """
    query_vector = self.embedding_function.embed_query(query)
    return self.similarity_search_with_score_by_vector(
        embedding=query_vector, k=k, filter=filter
    )
325
+
326
@property
def distance_strategy(self) -> Any:
    """Return the distance comparator of the model's ``embedding`` column.

    The comparator is chosen from ``self._distance_strategy``.

    Raises:
        ValueError: If the configured strategy is not a known member.
    """
    strategy = self._distance_strategy
    if strategy == DistanceStrategy.EUCLIDEAN:
        return self.EmbeddingStore.embedding.l2_distance
    if strategy == DistanceStrategy.COSINE:
        return self.EmbeddingStore.embedding.cosine_distance
    if strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        return self.EmbeddingStore.embedding.max_inner_product
    raise ValueError(
        f"Got unexpected value for distance: {self._distance_strategy}. "
        f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
    )
339
+
340
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Query the table for the ``k`` closest rows and wrap them as documents.

    Args:
        embedding: Query vector.
        k: Number of rows to fetch. Defaults to 4.
        filter: Forwarded to the query helper. Defaults to None.

    Returns:
        ``(Document, distance)`` pairs in query order.
    """
    rows = self.__query_collection(embedding=embedding, k=k, filter=filter)
    return self._results_to_docs_and_scores(rows)
349
+
350
def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
    """Convert query rows into ``(Document, distance)`` pairs.

    Each row is expected to expose the model as ``row.Article`` and the
    labelled distance as ``row.distance``. The document content is the
    article's ``abstract``. The score is ``None`` when no embedding function
    is configured.
    """
    pairs = []
    for row in results:
        document = Document(page_content=row.Article.abstract)
        if self.embedding_function is not None:
            score = row.distance
        else:
            score = None
        pairs.append((document, score))
    return pairs
363
+
364
def __query_collection(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
    """Fetch the ``k`` rows whose embedding is closest to ``embedding``.

    Args:
        embedding: Query vector.
        k: Maximum number of rows to return. Defaults to 4.
        filter: Accepted for interface compatibility. It is not applied to
            the query in this implementation.

    Returns:
        Rows of ``(model instance, distance)`` ordered by ascending distance.
    """
    distance = self.distance_strategy(embedding).label("distance")  # type: ignore
    with Session(self._conn) as session:
        results: List[Any] = (
            session.query(self.EmbeddingStore, distance)
            .order_by(sqlalchemy.asc("distance"))
            .limit(k)
            .all()
        )
    # Debug information goes to the logger rather than stdout.
    self.logger.debug("Nearest-neighbour query returned %d row(s)", len(results))
    return results
383
+
384
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents closest to ``embedding``, without scores.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Passed through to the scored
            search. Defaults to None.

    Returns:
        List of Documents most similar to the query vector.
    """
    scored = self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, filter=filter
    )
    documents = _results_to_docs(scored)
    return documents
405
+
406
@classmethod
def from_texts(
    cls: Type[CustomPGVector],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomPGVector:
    """Return a store initialized from texts, embedding them first.

    A Postgres connection string is required: either pass it as the
    ``connection_string`` keyword argument or set the
    ``PGVECTOR_CONNECTION_STRING`` environment variable.

    Args:
        texts: Texts to embed and store.
        embedding: Embedding model used to embed ``texts``.
        metadatas: Optional metadata dicts, one per text.
        table_name: Forwarded to the store constructor.
        distance_strategy: Metric used to rank results.
        ids: Optional ids, one per text.
        pre_delete_collection: Forwarded to the store constructor.
        kwargs: Extra arguments forwarded to the private factory.
    """
    # Materialise once so the same sequence is embedded and stored.
    texts = list(texts)
    embeddings = embedding.embed_documents(texts)

    return cls.__from(
        texts,
        embeddings,
        embedding,
        metadatas=metadatas,
        ids=ids,
        table_name=table_name,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **kwargs,
    )
437
+
438
@classmethod
def from_embeddings(
    cls,
    text_embeddings: List[Tuple[str, List[float]]],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomPGVector:
    """Construct a store from texts paired with pre-generated embeddings.

    A Postgres connection string is required: either pass it as the
    ``connection_string`` keyword argument or set the
    ``PGVECTOR_CONNECTION_STRING`` environment variable.

    Args:
        text_embeddings: ``(text, vector)`` pairs to store.
        embedding: Embedding model kept on the store for later queries.
        metadatas: Optional metadata dicts, one per pair.
        table_name: Forwarded to the store constructor.
        distance_strategy: Metric used to rank results.
        ids: Optional ids, one per pair.
        pre_delete_collection: Forwarded to the store constructor.
        kwargs: Extra arguments forwarded to the private factory.

    Example:
        .. code-block:: python

            embeddings = OpenAIEmbeddings()
            text_embeddings = embeddings.embed_documents(texts)
            text_embedding_pairs = list(zip(texts, text_embeddings))
            store = CustomPGVector.from_embeddings(
                text_embedding_pairs, embeddings
            )
    """
    texts = [pair[0] for pair in text_embeddings]
    embeddings = [pair[1] for pair in text_embeddings]

    return cls.__from(
        texts,
        embeddings,
        embedding,
        metadatas=metadatas,
        ids=ids,
        table_name=table_name,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **kwargs,
    )
482
+
483
@classmethod
def from_existing_index(
    cls: Type[CustomPGVector],
    embedding: Embeddings,
    table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomPGVector:
    """Get an instance of an existing store without inserting embeddings.

    Args:
        embedding: Embedding model used for later queries.
        table_name: Forwarded to the store constructor.
        distance_strategy: Metric used to rank results.
        pre_delete_collection: Forwarded to the store constructor.
        kwargs: May contain ``connection_string`` (otherwise the
            ``PGVECTOR_CONNECTION_STRING`` environment variable is used) and
            ``column_name``.

    Returns:
        The store instance.
    """
    connection_string = cls.get_connection_string(kwargs)

    # ``column_name`` is a required constructor argument; without it the
    # call below always raised TypeError. The default is the database column
    # name used by the ``Article`` model.
    column_name = kwargs.get("column_name", "abstract_embedding")

    return cls(
        connection_string=connection_string,
        table_name=table_name,
        column_name=column_name,
        embedding_function=embedding,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
    )
509
+
510
@classmethod
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
    """Resolve the Postgres connection string.

    The value is looked up under the ``connection_string`` key of ``kwargs``
    or in the ``PGVECTOR_CONNECTION_STRING`` environment variable, via
    ``get_from_dict_or_env``.

    Raises:
        ValueError: If the resolved value is empty.
    """
    connection_string: str = get_from_dict_or_env(
        data=kwargs,
        key="connection_string",
        env_key="PGVECTOR_CONNECTION_STRING",
    )

    if not connection_string:
        # The message fragments previously ran together without spaces.
        raise ValueError(
            "Postgres connection string is required. "
            "Either pass it as a parameter "
            "or set the PGVECTOR_CONNECTION_STRING environment variable."
        )

    return connection_string
526
+
527
@classmethod
def from_documents(
    cls: Type[CustomPGVector],
    documents: List[Document],
    embedding: Embeddings,
    table_name: str = "article",
    column_name: str = "embeding",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomPGVector:
    """Return a store initialized from documents.

    Each document's ``page_content`` and ``metadata`` are split out and
    passed to ``from_texts``. The connection string is resolved through
    ``get_connection_string`` (keyword argument or the
    ``PGVECTOR_CONNECTION_STRING`` environment variable) and forwarded in
    ``kwargs``.
    """
    texts = []
    metadatas = []
    for document in documents:
        texts.append(document.page_content)
    for document in documents:
        metadatas.append(document.metadata)

    kwargs["connection_string"] = cls.get_connection_string(kwargs)

    return cls.from_texts(
        texts=texts,
        pre_delete_collection=pre_delete_collection,
        embedding=embedding,
        distance_strategy=distance_strategy,
        metadatas=metadatas,
        ids=ids,
        table_name=table_name,
        column_name=column_name,
        **kwargs,
    )
563
+
564
@classmethod
def connection_string_from_db_params(
    cls,
    driver: str,
    host: str,
    port: int,
    database: str,
    user: str,
    password: str,
) -> str:
    """Assemble a ``postgresql+<driver>://`` URL from its parts.

    The values are interpolated as given; no escaping is applied.
    """
    credentials = f"{user}:{password}"
    location = f"{host}:{port}/{database}"
    return f"postgresql+{driver}://{credentials}@{location}"
576
+
577
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Pick the function that turns a raw distance into a relevance score.

    The 'correct' relevance function may differ depending on a few things,
    including:
    - the distance / similarity metric used by the VectorStore
    - the scale of your embeddings (OpenAI's are unit normed. Many others
      are not!)
    - embedding dimensionality
    - etc.

    Returns:
        The override supplied to the constructor when there is one,
        otherwise the function matching the configured distance strategy.

    Raises:
        ValueError: If the configured strategy has no matching function.
    """
    if self.override_relevance_score_fn is not None:
        return self.override_relevance_score_fn

    # Default strategy is to rely on the distance strategy provided in the
    # vectorstore constructor.
    strategy = self._distance_strategy
    if strategy == DistanceStrategy.COSINE:
        return self._cosine_relevance_score_fn
    if strategy == DistanceStrategy.EUCLIDEAN:
        return self._euclidean_relevance_score_fn
    if strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        return self._max_inner_product_relevance_score_fn
    # A space was missing between the two sentences of this message.
    raise ValueError(
        "No supported normalization function"
        f" for distance_strategy of {self._distance_strategy}. "
        "Consider providing relevance_score_fn to PGVector constructor."
    )
603
+
604
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs selected using the maximal marginal relevance with score
    to embedding vector.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the
            degree of diversity among the results with 0 corresponding to
            maximum diversity and 1 to minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Forwarded to the query helper.
            Defaults to None.

    Returns:
        List[Tuple[Document, float]]: The selected documents with their
        scores, in the order the query returned them.
    """
    results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)

    # Each row is (model instance, distance). Read the model by position:
    # the rows expose it under the model's class name, not as
    # ``EmbeddingStore``, so the old attribute lookup failed.
    embedding_list = [result[0].embedding for result in results]

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        embedding_list,
        k=k,
        lambda_mult=lambda_mult,
    )

    candidates = self._results_to_docs_and_scores(results)

    selected = set(mmr_selected)
    return [pair for index, pair in enumerate(candidates) if index in selected]
648
+
649
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the
            degree of diversity among the results with 0 corresponding to
            maximum diversity and 1 to minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Forwarded to the vector search.
            Defaults to None.

    Returns:
        List[Document]: List of Documents selected by maximal marginal
        relevance.
    """
    embedding = self.embedding_function.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        # Previously dropped; forwarded like the sibling *_with_score method.
        filter=filter,
        **kwargs,
    )
685
+
686
def max_marginal_relevance_search_with_score(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Embed ``query`` and run the scored maximal-marginal-relevance search.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the
            degree of diversity among the results with 0 corresponding to
            maximum diversity and 1 to minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Forwarded to the vector search.
            Defaults to None.

    Returns:
        List[Tuple[Document, float]]: List of Documents selected by maximal
        marginal relevance to the query and score for each.
    """
    query_vector = self.embedding_function.embed_query(query)
    return self.max_marginal_relevance_search_with_score_by_vector(
        embedding=query_vector,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
725
+
726
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Run the scored maximal-marginal-relevance search and drop the scores.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        embedding: Query vector to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the
            degree of diversity among the results with 0 corresponding to
            maximum diversity and 1 to minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Forwarded to the scored search.
            Defaults to None.

    Returns:
        List[Document]: List of Documents selected by maximal marginal
        relevance.
    """
    scored = self.max_marginal_relevance_search_with_score_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
    documents = _results_to_docs(scored)
    return documents
765
+
766
async def amax_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Async wrapper: the synchronous
    ``max_marginal_relevance_search_by_vector`` is run in the event loop's
    default executor, with the same arguments.
    """
    # This is a temporary workaround to make the similarity search
    # asynchronous. The proper solution is to make the similarity search
    # asynchronous in the vector store implementations.
    func = partial(
        self.max_marginal_relevance_search_by_vector,
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
    # Inside a coroutine the running loop is always available;
    # get_running_loop() states that intent explicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func)