Spaces:
Sleeping
Sleeping
Commit
·
98a9a5f
1
Parent(s):
d22e275
feat: postgres
Browse files- app.py +2 -8
- custom_pgvector.py +35 -45
app.py
CHANGED
@@ -17,7 +17,7 @@ from css import load_css
|
|
17 |
from custom_pgvector import CustomPGVector
|
18 |
from message import Message
|
19 |
|
20 |
-
CONNECTION_STRING = "
|
21 |
|
22 |
st.set_page_config(layout="wide")
|
23 |
|
@@ -27,13 +27,7 @@ chat_column, doc_column = st.columns([2, 1])
|
|
27 |
|
28 |
|
29 |
def connect() -> sqlalchemy.engine.Connection:
|
30 |
-
engine = sqlalchemy.create_engine(CONNECTION_STRING)
|
31 |
-
|
32 |
-
@event.listens_for(engine, "connect")
|
33 |
-
def receive_connect(connection, _):
|
34 |
-
connection.enable_load_extension(True)
|
35 |
-
sqlite_vss.load(connection)
|
36 |
-
connection.enable_load_extension(False)
|
37 |
|
38 |
conn = engine.connect()
|
39 |
return conn
|
|
|
17 |
from custom_pgvector import CustomPGVector
|
18 |
from message import Message
|
19 |
|
20 |
+
CONNECTION_STRING = "postgresql+psycopg2://postgres@/sorbobot?host=localhost"
|
21 |
|
22 |
st.set_page_config(layout="wide")
|
23 |
|
|
|
27 |
|
28 |
|
29 |
def connect() -> sqlalchemy.engine.Connection:
|
30 |
+
engine = sqlalchemy.create_engine(CONNECTION_STRING, pool_pre_ping=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
conn = engine.connect()
|
33 |
return conn
|
custom_pgvector.py
CHANGED
@@ -339,53 +339,43 @@ class CustomPGVector(VectorStore):
|
|
339 |
k: int = 4,
|
340 |
) -> List[Any]:
|
341 |
"""Query the collection."""
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
with matches as (
|
347 |
select
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
)
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
group by article.id
|
372 |
-
order by distance;
|
373 |
-
"""),
|
374 |
-
{"vector": vector, "limit": k}
|
375 |
-
)
|
376 |
-
results = cursor.fetchall()
|
377 |
-
results = pd.DataFrame(
|
378 |
-
results,
|
379 |
-
columns=[
|
380 |
-
"id",
|
381 |
-
"title",
|
382 |
-
"doi",
|
383 |
-
"abstract",
|
384 |
-
"keywords",
|
385 |
-
"authors",
|
386 |
-
"distance",
|
387 |
-
],
|
388 |
-
)
|
389 |
results = results.to_dict(orient="records")
|
390 |
return results
|
391 |
|
|
|
339 |
k: int = 4,
|
340 |
) -> List[Any]:
|
341 |
"""Query the collection."""
|
342 |
+
with Session(self._conn) as session:
|
343 |
+
results = session.execute(
|
344 |
+
text(
|
345 |
+
f"""
|
|
|
346 |
select
|
347 |
+
a.id,
|
348 |
+
a.title,
|
349 |
+
a.doi,
|
350 |
+
a.abstract,
|
351 |
+
string_agg(distinct keyword."name", ',') as keywords,
|
352 |
+
string_agg(distinct author."name", ',') as authors,
|
353 |
+
abstract_embedding <-> '{str(embedding)}' as distance
|
354 |
+
from article a
|
355 |
+
left join article_keyword ON article_keyword.article_id = a.id
|
356 |
+
left join keyword on article_keyword.keyword_id = keyword.id
|
357 |
+
left join article_author ON article_author.article_id = a.id
|
358 |
+
left join author on author.id = article_author.author_id
|
359 |
+
where abstract != 'NaN'
|
360 |
+
GROUP BY a.id
|
361 |
+
ORDER BY distance
|
362 |
+
LIMIT {k};
|
363 |
+
"""
|
364 |
)
|
365 |
+
)
|
366 |
+
results = results.fetchall()
|
367 |
+
results = pd.DataFrame(
|
368 |
+
results,
|
369 |
+
columns=[
|
370 |
+
"id",
|
371 |
+
"title",
|
372 |
+
"doi",
|
373 |
+
"abstract",
|
374 |
+
"keywords",
|
375 |
+
"authors",
|
376 |
+
"distance",
|
377 |
+
],
|
378 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
results = results.to_dict(orient="records")
|
380 |
return results
|
381 |
|