leo-bourrel committed on
Commit 63dc793
Parents: 389b0ad 550e87b

Merge branch 'postgres_retrieval' into main

Files changed (2):
  1. app.py +29 -8
  2. custom_pgvector.py +789 -0
app.py CHANGED
@@ -3,12 +3,16 @@ import os
 import streamlit as st
 import streamlit.components.v1 as components
 from css import load_css
+from custom_pgvector import CustomPGVector
 from langchain import OpenAI
 from langchain.callbacks import get_openai_callback
-from langchain.chains import ConversationChain
-from langchain.chains.conversation.memory import ConversationSummaryMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain.embeddings import GPT4AllEmbeddings
 from message import Message
 
+CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
+
 
 def initialize_session_state():
     if "history" not in st.session_state:
@@ -16,21 +20,38 @@ def initialize_session_state():
     if "token_count" not in st.session_state:
         st.session_state.token_count = 0
     if "conversation" not in st.session_state:
+        embeddings = GPT4AllEmbeddings()
+
+        db = CustomPGVector(
+            embedding_function=embeddings,
+            table_name="article",
+            column_name="abstract_embedding",
+            connection_string=CONNECTION_STRING,
+        )
+
+        retriever = db.as_retriever()
+
         llm = OpenAI(
             temperature=0,
             openai_api_key=os.environ["OPENAI_API_KEY"],
-            model_name="text-davinci-003",
+            model="text-davinci-003",
         )
-        st.session_state.conversation = ConversationChain(
-            llm=llm,
-            memory=ConversationSummaryMemory(llm=llm),
+
+        st.session_state.memory = ConversationBufferMemory()
+        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
+            llm=llm, retriever=retriever, verbose=True
         )
 
 
 def on_click_callback():
     with get_openai_callback() as cb:
         human_prompt = st.session_state.human_prompt
-        llm_response = st.session_state.conversation.run(human_prompt)
+        llm_response = st.session_state.conversation.run(
+            {
+                "question": human_prompt,
+                "chat_history": st.session_state.memory.buffer,
+            }
+        )
         st.session_state.history.append(Message("human", human_prompt))
         st.session_state.history.append(Message("ai", llm_response))
         st.session_state.token_count += cb.total_tokens
@@ -84,7 +105,7 @@ information_placeholder.caption(
     f"""
     Used {st.session_state.token_count} tokens \n
     Debug Langchain conversation:
-    {st.session_state.conversation.memory.buffer}
+    {st.session_state.memory.buffer}
     """
 )
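For context, a minimal standalone sketch of the retrieval wiring this diff sets up, with the Streamlit layer stripped away. It is illustrative only: it assumes the local `sorbobot` database from the diff, an existing `article` table with a populated `abstract_embedding` pgvector column, and an `OPENAI_API_KEY` in the environment; the sample question is made up.

import os

from custom_pgvector import CustomPGVector
from langchain import OpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import GPT4AllEmbeddings

# Same connection string as app.py; adjust host and credentials as needed.
CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"

# Vector store over the existing `article` table; nothing is inserted here.
db = CustomPGVector(
    embedding_function=GPT4AllEmbeddings(),
    table_name="article",
    column_name="abstract_embedding",
    connection_string=CONNECTION_STRING,
)

chain = ConversationalRetrievalChain.from_llm(
    llm=OpenAI(temperature=0, openai_api_key=os.environ["OPENAI_API_KEY"]),
    retriever=db.as_retriever(),
)

# The chain expects the new question plus the accumulated chat history,
# which app.py keeps in a separate ConversationBufferMemory.
print(chain.run({"question": "Which articles mention pgvector?", "chat_history": ""}))

Note that ConversationalRetrievalChain first condenses the question against the chat history and only then queries the retriever, which is why app.py tracks history itself instead of attaching a memory to the chain.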
custom_pgvector.py ADDED
@@ -0,0 +1,789 @@
from __future__ import annotations

import asyncio
import contextlib
import enum
import logging
import uuid
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np
import sqlalchemy
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from pgvector.sqlalchemy import Vector
from sqlalchemy import delete
from sqlalchemy.orm import Session, declarative_base


class DistanceStrategy(str, enum.Enum):
    """Enumerator of the distance strategies."""

    EUCLIDEAN = "l2"
    COSINE = "cosine"
    MAX_INNER_PRODUCT = "inner"


DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE

Base = declarative_base()  # type: Any


_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
    """Return docs from docs and scores."""
    return [doc for doc, _ in docs_and_scores]


class Article(Base):
    """ORM model for the `article` table that holds the embeddings."""

    __tablename__ = "article"

    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, nullable=False)
    title = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    abstract = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    embedding: Vector = sqlalchemy.Column("abstract_embedding", Vector(None))
    doi = sqlalchemy.Column(sqlalchemy.String, nullable=True)


class CustomPGVector(VectorStore):
    """`Postgres`/`PGVector` vector store.

    To use, you should have the ``pgvector`` python package installed.

    Args:
        connection_string: Postgres connection string.
        embedding_function: Any embedding function implementing
            `langchain.embeddings.base.Embeddings` interface.
        table_name: The name of the Postgres table holding the embeddings.
            (default: article) The table is expected to exist already; only
            the `vector` extension is created on initialization.
        column_name: The name of the vector column to search.
        distance_strategy: The distance strategy to use. (default: COSINE)
        pre_delete_collection: If True, will delete the collection if it exists.
            (default: False). Useful for testing.

    Example:
        .. code-block:: python

            from custom_pgvector import CustomPGVector
            from langchain.embeddings.openai import OpenAIEmbeddings

            CONNECTION_STRING = "postgresql+psycopg2://hwc@localhost:5432/test3"
            TABLE_NAME = "article"
            embeddings = OpenAIEmbeddings()
            vectorstore = CustomPGVector.from_documents(
                embedding=embeddings,
                documents=docs,
                table_name=TABLE_NAME,
                connection_string=CONNECTION_STRING,
            )
    """

    def __init__(
        self,
        connection_string: str,
        embedding_function: Embeddings,
        table_name: str,
        column_name: str,
        collection_metadata: Optional[dict] = None,
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        pre_delete_collection: bool = False,
        logger: Optional[logging.Logger] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        self.connection_string = connection_string
        self.embedding_function = embedding_function
        self.table_name = table_name
        self.column_name = column_name
        self.collection_metadata = collection_metadata
        self._distance_strategy = distance_strategy
        self.pre_delete_collection = pre_delete_collection
        self.logger = logger or logging.getLogger(__name__)
        self.override_relevance_score_fn = relevance_score_fn
        self.__post_init__()

    def __post_init__(
        self,
    ) -> None:
        """Initialize the store."""
        self._conn = self.connect()
        self.create_vector_extension()

        self.EmbeddingStore = Article

    @property
    def embeddings(self) -> Embeddings:
        return self.embedding_function

    def connect(self) -> sqlalchemy.engine.Connection:
        engine = sqlalchemy.create_engine(self.connection_string)
        conn = engine.connect()
        return conn

    def create_vector_extension(self) -> None:
        try:
            with Session(self._conn) as session:
                statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
                session.execute(statement)
                session.commit()
        except Exception as e:
            self.logger.exception(e)

    def drop_tables(self) -> None:
        with self._conn.begin():
            Base.metadata.drop_all(self._conn)

    @contextlib.contextmanager
    def _make_session(self) -> Generator[Session, None, None]:
        """Create a context manager for a session bound to the shared connection."""
        yield Session(self._conn)

    def delete(
        self,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Delete vectors by ids.

        Args:
            ids: List of ids to delete.
        """
        with Session(self._conn) as session:
            if ids is not None:
                self.logger.debug(
                    "Trying to delete vectors by ids (represented by the model "
                    "using the custom ids field)"
                )
                # NOTE: assumes the mapped model exposes a `custom_id` column;
                # the `Article` model above does not define one yet.
                stmt = delete(self.EmbeddingStore).where(
                    self.EmbeddingStore.custom_id.in_(ids)
                )
                session.execute(stmt)
                session.commit()

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        table_name: str = "article",
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        connection_string: Optional[str] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> CustomPGVector:
        if not metadatas:
            metadatas = [{} for _ in texts]
        if connection_string is None:
            connection_string = cls.get_connection_string(kwargs)

        store = cls(
            connection_string=connection_string,
            table_name=table_name,
            embedding_function=embedding,
            distance_strategy=distance_strategy,
            pre_delete_collection=pre_delete_collection,
            **kwargs,
        )

        store.add_embeddings(
            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
        )

        return store

    def add_embeddings(
        self,
        texts: Iterable[str],
        embeddings: List[List[float]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add embeddings to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            embeddings: List of list of embedding vectors.
            metadatas: List of metadatas associated with the texts.
            ids: Optional list of ids; generated when omitted.
            kwargs: vectorstore specific parameters
        """
        texts = list(texts)
        if not metadatas:
            metadatas = [{} for _ in texts]
        if ids is None:
            # The zip() below would raise on a None `ids`.
            ids = [str(uuid.uuid4()) for _ in texts]

        with Session(self._conn) as session:
            for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
                # NOTE: these field names follow the upstream PGVector store and
                # assume matching columns on the mapped model.
                embedding_store = self.EmbeddingStore(
                    embedding=embedding,
                    document=text,
                    cmetadata=metadata,
                    custom_id=id,
                )
                session.add(embedding_store)
            session.commit()

        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        embeddings = self.embedding_function.embed_documents(list(texts))
        return self.add_embeddings(
            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
        )

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with PGVector with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query.
        """
        embedding = self.embedding_function.embed_query(text=query)
        return self.similarity_search_by_vector(
            embedding=embedding,
            k=k,
            filter=filter,
        )

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        embedding = self.embedding_function.embed_query(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter
        )
        return docs

    @property
    def distance_strategy(self) -> Any:
        if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
            return self.EmbeddingStore.embedding.l2_distance
        elif self._distance_strategy == DistanceStrategy.COSINE:
            return self.EmbeddingStore.embedding.cosine_distance
        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            return self.EmbeddingStore.embedding.max_inner_product
        else:
            raise ValueError(
                f"Got unexpected value for distance: {self._distance_strategy}. "
                f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
            )

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict] = None,
    ) -> List[Tuple[Document, float]]:
        results = self.__query_collection(embedding=embedding, k=k, filter=filter)

        return self._results_to_docs_and_scores(results)

    def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
        """Return docs and scores from results."""
        docs = [
            (
                Document(
                    page_content=result.Article.abstract,
                    metadata={
                        "id": result.Article.id,
                        "title": result.Article.title,
                        "doi": result.Article.doi,
                    },
                ),
                result.distance if self.embedding_function is not None else None,
            )
            for result in results
        ]
        return docs

    def __query_collection(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Any]:
        """Query the collection."""
        # NOTE: `filter` is accepted for interface compatibility but is not
        # applied to the query yet.
        with Session(self._conn) as session:
            results: List[Any] = (
                session.query(
                    self.EmbeddingStore,
                    self.distance_strategy(embedding).label("distance"),  # type: ignore
                )
                .order_by(sqlalchemy.asc("distance"))
                .limit(k)
                .all()
            )
            return results

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query vector.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter
        )
        return _results_to_docs(docs_and_scores)

    @classmethod
    def from_texts(
        cls: Type[CustomPGVector],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        ids: Optional[List[str]] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> CustomPGVector:
        """Return VectorStore initialized from texts and embeddings.

        A Postgres connection string is required: either pass it as a
        parameter or set the PGVECTOR_CONNECTION_STRING environment variable.
        """
        embeddings = embedding.embed_documents(list(texts))

        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            table_name=table_name,
            distance_strategy=distance_strategy,
            pre_delete_collection=pre_delete_collection,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        ids: Optional[List[str]] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> CustomPGVector:
        """Construct a CustomPGVector wrapper from raw documents and
        pre-generated embeddings.

        Return VectorStore initialized from documents and embeddings.
        A Postgres connection string is required: either pass it as a
        parameter or set the PGVECTOR_CONNECTION_STRING environment variable.

        Example:
            .. code-block:: python

                from custom_pgvector import CustomPGVector
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                vectorstore = CustomPGVector.from_embeddings(
                    text_embedding_pairs, embeddings
                )
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]

        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            table_name=table_name,
            distance_strategy=distance_strategy,
            pre_delete_collection=pre_delete_collection,
            **kwargs,
        )

    @classmethod
    def from_existing_index(
        cls: Type[CustomPGVector],
        embedding: Embeddings,
        table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> CustomPGVector:
        """Get an instance of an existing CustomPGVector store. This method
        returns the instance of the store without inserting any new embeddings.
        """
        connection_string = cls.get_connection_string(kwargs)

        store = cls(
            connection_string=connection_string,
            table_name=table_name,
            embedding_function=embedding,
            distance_strategy=distance_strategy,
            pre_delete_collection=pre_delete_collection,
        )

        return store

    @classmethod
    def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
        connection_string: str = get_from_dict_or_env(
            data=kwargs,
            key="connection_string",
            env_key="PGVECTOR_CONNECTION_STRING",
        )

        if not connection_string:
            raise ValueError(
                "Postgres connection string is required. "
                "Either pass it as a parameter "
                "or set the PGVECTOR_CONNECTION_STRING environment variable."
            )

        return connection_string

    @classmethod
    def from_documents(
        cls: Type[CustomPGVector],
        documents: List[Document],
        embedding: Embeddings,
        table_name: str = "article",
        column_name: str = "embedding",
        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
        ids: Optional[List[str]] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> CustomPGVector:
        """Return VectorStore initialized from documents and embeddings.

        A Postgres connection string is required: either pass it as a
        parameter or set the PGVECTOR_CONNECTION_STRING environment variable.
        """
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        connection_string = cls.get_connection_string(kwargs)

        kwargs["connection_string"] = connection_string

        return cls.from_texts(
            texts=texts,
            pre_delete_collection=pre_delete_collection,
            embedding=embedding,
            distance_strategy=distance_strategy,
            metadatas=metadatas,
            ids=ids,
            table_name=table_name,
            column_name=column_name,
            **kwargs,
        )

    @classmethod
    def connection_string_from_db_params(
        cls,
        driver: str,
        host: str,
        port: int,
        database: str,
        user: str,
        password: str,
    ) -> str:
        """Return connection string from database parameters."""
        return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function may differ depending on a few things,
        including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn is not None:
            return self.override_relevance_score_fn

        # Default strategy is to rely on the distance strategy provided
        # in the vectorstore constructor
        if self._distance_strategy == DistanceStrategy.COSINE:
            return self._cosine_relevance_score_fn
        elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
            return self._euclidean_relevance_score_fn
        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance_strategy of {self._distance_strategy}."
                " Consider providing relevance_score_fn to the CustomPGVector"
                " constructor."
            )

    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using the maximal marginal relevance with score
        to embedding vector.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            lambda_mult (float): Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of Documents selected by maximal marginal
                relevance to the query and score for each.
        """
        results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)

        # Rows are keyed by the mapped class name (`Article`), as in
        # `_results_to_docs_and_scores` above.
        embedding_list = [result.Article.embedding for result in results]

        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embedding_list,
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = self._results_to_docs_and_scores(results)

        return [r for i, r in enumerate(candidates) if i in mmr_selected]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query (str): Text to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            lambda_mult (float): Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )

    def max_marginal_relevance_search_with_score(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using the maximal marginal relevance with score.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query (str): Text to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            lambda_mult (float): Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of Documents selected by maximal marginal
                relevance to the query and score for each.
        """
        embedding = self.embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_with_score_by_vector(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return docs

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance
        to embedding vector.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
                Defaults to 20.
            lambda_mult (float): Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )

        return _results_to_docs(docs_and_scores)

    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance."""

        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.max_marginal_relevance_search_by_vector,
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)
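
Finally, a small sketch for exercising the store directly, handy for checking the retrieval path without the LLM in the loop. Table, column, and connection values mirror app.py; the query string is made up.

from custom_pgvector import CustomPGVector
from langchain.embeddings import GPT4AllEmbeddings

db = CustomPGVector(
    embedding_function=GPT4AllEmbeddings(),
    table_name="article",
    column_name="abstract_embedding",
    connection_string="postgresql+psycopg2://localhost/sorbobot",
)

# Top-4 nearest abstracts by cosine distance (the default strategy).
for doc, score in db.similarity_search_with_score("graph neural networks", k=4):
    print(f"{score:.3f}  {doc.metadata['title']}  (doi={doc.metadata['doi']})")

Lower scores are better here: __query_collection orders by ascending distance, so a cosine distance of 0 is an exact match.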