Spaces:
Sleeping
Sleeping
Commit
·
83ed4d1
1
Parent(s):
3378b23
feat: move sql connection outside of custom pgvector
Browse files- app.py +11 -1
- custom_pgvector.py +5 -11
app.py
CHANGED
@@ -11,6 +11,7 @@ from langchain.chains import ConversationalRetrievalChain
|
|
11 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
12 |
from langchain.embeddings import GPT4AllEmbeddings
|
13 |
from message import Message
|
|
|
14 |
|
15 |
CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
|
16 |
|
@@ -21,6 +22,15 @@ st.title("Sorbobot - Le futur de la recherche scientifique interactive")
|
|
21 |
chat_column, doc_column = st.columns([2, 1])
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def initialize_session_state():
|
25 |
if "history" not in st.session_state:
|
26 |
st.session_state.history = []
|
@@ -33,7 +43,7 @@ def initialize_session_state():
|
|
33 |
embedding_function=embeddings,
|
34 |
table_name="article",
|
35 |
column_name="abstract_embedding",
|
36 |
-
|
37 |
)
|
38 |
|
39 |
retriever = db.as_retriever()
|
|
|
11 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
12 |
from langchain.embeddings import GPT4AllEmbeddings
|
13 |
from message import Message
|
14 |
+
import sqlalchemy
|
15 |
|
16 |
CONNECTION_STRING = "postgresql+psycopg2://localhost/sorbobot"
|
17 |
|
|
|
22 |
chat_column, doc_column = st.columns([2, 1])
|
23 |
|
24 |
|
25 |
+
def connect() -> sqlalchemy.engine.Connection:
|
26 |
+
engine = sqlalchemy.create_engine(CONNECTION_STRING)
|
27 |
+
conn = engine.connect()
|
28 |
+
return conn
|
29 |
+
|
30 |
+
|
31 |
+
conn = connect()
|
32 |
+
|
33 |
+
|
34 |
def initialize_session_state():
|
35 |
if "history" not in st.session_state:
|
36 |
st.session_state.history = []
|
|
|
43 |
embedding_function=embeddings,
|
44 |
table_name="article",
|
45 |
column_name="abstract_embedding",
|
46 |
+
connection=conn,
|
47 |
)
|
48 |
|
49 |
retriever = db.as_retriever()
|
custom_pgvector.py
CHANGED
@@ -70,7 +70,7 @@ class CustomPGVector(VectorStore):
|
|
70 |
To use, you should have the ``pgvector`` python package installed.
|
71 |
|
72 |
Args:
|
73 |
-
|
74 |
embedding_function: Any embedding function implementing
|
75 |
`langchain.embeddings.base.Embeddings` interface.
|
76 |
table_name: The name of the collection to use. (default: langchain)
|
@@ -87,14 +87,13 @@ class CustomPGVector(VectorStore):
|
|
87 |
from langchain.vectorstores import PGVector
|
88 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
89 |
|
90 |
-
CONNECTION_STRING = "postgresql+psycopg2://hwc@localhost:5432/test3"
|
91 |
COLLECTION_NAME = "state_of_the_union_test"
|
92 |
embeddings = OpenAIEmbeddings()
|
93 |
vectorestore = PGVector.from_documents(
|
94 |
embedding=embeddings,
|
95 |
documents=docs,
|
96 |
table_name=COLLECTION_NAME,
|
97 |
-
|
98 |
)
|
99 |
|
100 |
|
@@ -102,7 +101,7 @@ class CustomPGVector(VectorStore):
|
|
102 |
|
103 |
def __init__(
|
104 |
self,
|
105 |
-
|
106 |
embedding_function: Embeddings,
|
107 |
table_name: str,
|
108 |
column_name: str,
|
@@ -111,7 +110,7 @@ class CustomPGVector(VectorStore):
|
|
111 |
pre_delete_collection: bool = False,
|
112 |
logger: Optional[logging.Logger] = None,
|
113 |
) -> None:
|
114 |
-
self.
|
115 |
self.embedding_function = embedding_function
|
116 |
self.table_name = table_name
|
117 |
self.column_name = column_name
|
@@ -127,7 +126,7 @@ class CustomPGVector(VectorStore):
|
|
127 |
"""
|
128 |
Initialize the store.
|
129 |
"""
|
130 |
-
self._conn = self.connect()
|
131 |
self.create_vector_extension()
|
132 |
|
133 |
self.EmbeddingStore = Article
|
@@ -136,11 +135,6 @@ class CustomPGVector(VectorStore):
|
|
136 |
def embeddings(self) -> Embeddings:
|
137 |
return self.embedding_function
|
138 |
|
139 |
-
def connect(self) -> sqlalchemy.engine.Connection:
|
140 |
-
engine = sqlalchemy.create_engine(self.connection_string)
|
141 |
-
conn = engine.connect()
|
142 |
-
return conn
|
143 |
-
|
144 |
def create_vector_extension(self) -> None:
|
145 |
try:
|
146 |
with Session(self._conn) as session:
|
|
|
70 |
To use, you should have the ``pgvector`` python package installed.
|
71 |
|
72 |
Args:
|
73 |
+
connection: Postgres connection string.
|
74 |
embedding_function: Any embedding function implementing
|
75 |
`langchain.embeddings.base.Embeddings` interface.
|
76 |
table_name: The name of the collection to use. (default: langchain)
|
|
|
87 |
from langchain.vectorstores import PGVector
|
88 |
from langchain.embeddings.openai import OpenAIEmbeddings
|
89 |
|
|
|
90 |
COLLECTION_NAME = "state_of_the_union_test"
|
91 |
embeddings = OpenAIEmbeddings()
|
92 |
vectorestore = PGVector.from_documents(
|
93 |
embedding=embeddings,
|
94 |
documents=docs,
|
95 |
table_name=COLLECTION_NAME,
|
96 |
+
connection=connection,
|
97 |
)
|
98 |
|
99 |
|
|
|
101 |
|
102 |
def __init__(
|
103 |
self,
|
104 |
+
connection: sqlalchemy.engine.Connection,
|
105 |
embedding_function: Embeddings,
|
106 |
table_name: str,
|
107 |
column_name: str,
|
|
|
110 |
pre_delete_collection: bool = False,
|
111 |
logger: Optional[logging.Logger] = None,
|
112 |
) -> None:
|
113 |
+
self._conn = connection
|
114 |
self.embedding_function = embedding_function
|
115 |
self.table_name = table_name
|
116 |
self.column_name = column_name
|
|
|
126 |
"""
|
127 |
Initialize the store.
|
128 |
"""
|
129 |
+
# self._conn = self.connect()
|
130 |
self.create_vector_extension()
|
131 |
|
132 |
self.EmbeddingStore = Article
|
|
|
135 |
def embeddings(self) -> Embeddings:
|
136 |
return self.embedding_function
|
137 |
|
|
|
|
|
|
|
|
|
|
|
138 |
def create_vector_extension(self) -> None:
|
139 |
try:
|
140 |
with Session(self._conn) as session:
|