leo-bourrel commited on
Commit
98a9a5f
·
1 Parent(s): d22e275

feat: postgres

Browse files
Files changed (2) hide show
  1. app.py +2 -8
  2. custom_pgvector.py +35 -45
app.py CHANGED
@@ -17,7 +17,7 @@ from css import load_css
17
  from custom_pgvector import CustomPGVector
18
  from message import Message
19
 
20
- CONNECTION_STRING = "sqlite:///data/sorbobot.db"
21
 
22
  st.set_page_config(layout="wide")
23
 
@@ -27,13 +27,7 @@ chat_column, doc_column = st.columns([2, 1])
27
 
28
 
29
  def connect() -> sqlalchemy.engine.Connection:
30
- engine = sqlalchemy.create_engine(CONNECTION_STRING)
31
-
32
- @event.listens_for(engine, "connect")
33
- def receive_connect(connection, _):
34
- connection.enable_load_extension(True)
35
- sqlite_vss.load(connection)
36
- connection.enable_load_extension(False)
37
 
38
  conn = engine.connect()
39
  return conn
 
17
  from custom_pgvector import CustomPGVector
18
  from message import Message
19
 
20
+ CONNECTION_STRING = "postgresql+psycopg2://postgres@/sorbobot?host=localhost"
21
 
22
  st.set_page_config(layout="wide")
23
 
 
27
 
28
 
29
  def connect() -> sqlalchemy.engine.Connection:
30
+ engine = sqlalchemy.create_engine(CONNECTION_STRING, pool_pre_ping=True)
 
 
 
 
 
 
31
 
32
  conn = engine.connect()
33
  return conn
custom_pgvector.py CHANGED
@@ -339,53 +339,43 @@ class CustomPGVector(VectorStore):
339
  k: int = 4,
340
  ) -> List[Any]:
341
  """Query the collection."""
342
- vector = bytearray(struct.pack("f" * len(embedding), *embedding))
343
-
344
- cursor = self._conn.execute(
345
- text("""
346
- with matches as (
347
  select
348
- rowid,
349
- distance
350
- from vss_article
351
- where vss_search(
352
- abstract_embedding,
353
- :vector
354
- )
355
- limit 5
 
 
 
 
 
 
 
 
 
356
  )
357
- select
358
- article.id,
359
- article.title,
360
- article.doi,
361
- article.abstract,
362
- group_concat(keyword."name", ',') as keywords,
363
- group_concat(author."name", ',') as authors,
364
- matches.distance
365
- from matches
366
- left join article on matches.rowid = article.rowid
367
- left join article_keyword ak ON ak.article_id = article.id
368
- left join keyword on ak.keyword_id = keyword.id
369
- left join article_author ON article_author.article_id = article.id
370
- left join author on author.id = article_author.author_id
371
- group by article.id
372
- order by distance;
373
- """),
374
- {"vector": vector, "limit": k}
375
- )
376
- results = cursor.fetchall()
377
- results = pd.DataFrame(
378
- results,
379
- columns=[
380
- "id",
381
- "title",
382
- "doi",
383
- "abstract",
384
- "keywords",
385
- "authors",
386
- "distance",
387
- ],
388
- )
389
  results = results.to_dict(orient="records")
390
  return results
391
 
 
339
  k: int = 4,
340
  ) -> List[Any]:
341
  """Query the collection."""
342
+ with Session(self._conn) as session:
343
+ results = session.execute(
344
+ text(
345
+ f"""
 
346
  select
347
+ a.id,
348
+ a.title,
349
+ a.doi,
350
+ a.abstract,
351
+ string_agg(distinct keyword."name", ',') as keywords,
352
+ string_agg(distinct author."name", ',') as authors,
353
+ abstract_embedding <-> '{str(embedding)}' as distance
354
+ from article a
355
+ left join article_keyword ON article_keyword.article_id = a.id
356
+ left join keyword on article_keyword.keyword_id = keyword.id
357
+ left join article_author ON article_author.article_id = a.id
358
+ left join author on author.id = article_author.author_id
359
+ where abstract != 'NaN'
360
+ GROUP BY a.id
361
+ ORDER BY distance
362
+ LIMIT {k};
363
+ """
364
  )
365
+ )
366
+ results = results.fetchall()
367
+ results = pd.DataFrame(
368
+ results,
369
+ columns=[
370
+ "id",
371
+ "title",
372
+ "doi",
373
+ "abstract",
374
+ "keywords",
375
+ "authors",
376
+ "distance",
377
+ ],
378
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  results = results.to_dict(orient="records")
380
  return results
381