leo-bourrel commited on
Commit
83ed4d1
·
1 Parent(s): 3378b23

feat: move sql connection outside of custom pgvector

Browse files
Files changed (2) hide show
  1. app.py +11 -1
  2. custom_pgvector.py +5 -11
app.py CHANGED
@@ -11,6 +11,7 @@ from langchain.chains import ConversationalRetrievalChain
11
  from langchain.chains.conversation.memory import ConversationBufferMemory
12
  from langchain.embeddings import GPT4AllEmbeddings
13
  from message import Message
 
14
 
15
  CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
16
 
@@ -21,6 +22,15 @@ st.title("Sorbobot - Le futur de la recherche scientifique interactive")
21
  chat_column, doc_column = st.columns([2, 1])
22
 
23
 
 
 
 
 
 
 
 
 
 
24
  def initialize_session_state():
25
  if "history" not in st.session_state:
26
  st.session_state.history = []
@@ -33,7 +43,7 @@ def initialize_session_state():
33
  embedding_function=embeddings,
34
  table_name="article",
35
  column_name="abstract_embedding",
36
- connection_string=CONNECTION_STRING,
37
  )
38
 
39
  retriever = db.as_retriever()
 
11
  from langchain.chains.conversation.memory import ConversationBufferMemory
12
  from langchain.embeddings import GPT4AllEmbeddings
13
  from message import Message
14
+ import sqlalchemy
15
 
16
  CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
17
 
 
22
  chat_column, doc_column = st.columns([2, 1])
23
 
24
 
25
+ def connect() -> sqlalchemy.engine.Connection:
26
+ engine = sqlalchemy.create_engine(CONNECTION_STRING)
27
+ conn = engine.connect()
28
+ return conn
29
+
30
+
31
+ conn = connect()
32
+
33
+
34
  def initialize_session_state():
35
  if "history" not in st.session_state:
36
  st.session_state.history = []
 
43
  embedding_function=embeddings,
44
  table_name="article",
45
  column_name="abstract_embedding",
46
+ connection=conn,
47
  )
48
 
49
  retriever = db.as_retriever()
custom_pgvector.py CHANGED
@@ -70,7 +70,7 @@ class CustomPGVector(VectorStore):
70
  To use, you should have the ``pgvector`` python package installed.
71
 
72
  Args:
73
- connection_string: Postgres connection string.
74
  embedding_function: Any embedding function implementing
75
  `langchain.embeddings.base.Embeddings` interface.
76
  table_name: The name of the collection to use. (default: langchain)
@@ -87,14 +87,13 @@ class CustomPGVector(VectorStore):
87
  from langchain.vectorstores import PGVector
88
  from langchain.embeddings.openai import OpenAIEmbeddings
89
 
90
- CONNECTION_STRING = "postgresql+psycopg2://hwc@localhost:5432/test3"
91
  COLLECTION_NAME = "state_of_the_union_test"
92
  embeddings = OpenAIEmbeddings()
93
  vectorestore = PGVector.from_documents(
94
  embedding=embeddings,
95
  documents=docs,
96
  table_name=COLLECTION_NAME,
97
- connection_string=CONNECTION_STRING,
98
  )
99
 
100
 
@@ -102,7 +101,7 @@ class CustomPGVector(VectorStore):
102
 
103
  def __init__(
104
  self,
105
- connection_string: str,
106
  embedding_function: Embeddings,
107
  table_name: str,
108
  column_name: str,
@@ -111,7 +110,7 @@ class CustomPGVector(VectorStore):
111
  pre_delete_collection: bool = False,
112
  logger: Optional[logging.Logger] = None,
113
  ) -> None:
114
- self.connection_string = connection_string
115
  self.embedding_function = embedding_function
116
  self.table_name = table_name
117
  self.column_name = column_name
@@ -127,7 +126,7 @@ class CustomPGVector(VectorStore):
127
  """
128
  Initialize the store.
129
  """
130
- self._conn = self.connect()
131
  self.create_vector_extension()
132
 
133
  self.EmbeddingStore = Article
@@ -136,11 +135,6 @@ class CustomPGVector(VectorStore):
136
  def embeddings(self) -> Embeddings:
137
  return self.embedding_function
138
 
139
- def connect(self) -> sqlalchemy.engine.Connection:
140
- engine = sqlalchemy.create_engine(self.connection_string)
141
- conn = engine.connect()
142
- return conn
143
-
144
  def create_vector_extension(self) -> None:
145
  try:
146
  with Session(self._conn) as session:
 
70
  To use, you should have the ``pgvector`` python package installed.
71
 
72
  Args:
73
+ connection: Postgres connection string.
74
  embedding_function: Any embedding function implementing
75
  `langchain.embeddings.base.Embeddings` interface.
76
  table_name: The name of the collection to use. (default: langchain)
 
87
  from langchain.vectorstores import PGVector
88
  from langchain.embeddings.openai import OpenAIEmbeddings
89
 
 
90
  COLLECTION_NAME = "state_of_the_union_test"
91
  embeddings = OpenAIEmbeddings()
92
  vectorestore = PGVector.from_documents(
93
  embedding=embeddings,
94
  documents=docs,
95
  table_name=COLLECTION_NAME,
96
+ connection=connection,
97
  )
98
 
99
 
 
101
 
102
  def __init__(
103
  self,
104
+ connection: sqlalchemy.engine.Connection,
105
  embedding_function: Embeddings,
106
  table_name: str,
107
  column_name: str,
 
110
  pre_delete_collection: bool = False,
111
  logger: Optional[logging.Logger] = None,
112
  ) -> None:
113
+ self._conn = connection
114
  self.embedding_function = embedding_function
115
  self.table_name = table_name
116
  self.column_name = column_name
 
126
  """
127
  Initialize the store.
128
  """
129
+ # self._conn = self.connect()
130
  self.create_vector_extension()
131
 
132
  self.EmbeddingStore = Article
 
135
  def embeddings(self) -> Embeddings:
136
  return self.embedding_function
137
 
 
 
 
 
 
138
  def create_vector_extension(self) -> None:
139
  try:
140
  with Session(self._conn) as session: