Léo Bourrel committed on
Commit
24d1b6f
·
1 Parent(s): b8c8744

feat: replace sqlalchemy query with raw SQL execution

Browse files
Files changed (1) hide show
  1. custom_pgvector.py +33 -14
custom_pgvector.py CHANGED
@@ -1,5 +1,5 @@
1
  from __future__ import annotations
2
-
3
  import asyncio
4
  import contextlib
5
  import enum
@@ -28,6 +28,7 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
28
  from pgvector.sqlalchemy import Vector
29
  from sqlalchemy import delete
30
  from sqlalchemy.orm import Session, declarative_base, relationship
 
31
 
32
 
33
  class DistanceStrategy(str, enum.Enum):
@@ -348,16 +349,19 @@ class CustomPGVector(VectorStore):
348
  docs = [
349
  (
350
  Document(
351
- page_content=result.Article.abstract,
352
  metadata={
353
- "id": result.Article.id,
354
- "title": result.Article.title,
355
- "doi": result.Article.doi,
 
 
 
356
  },
357
  ),
358
  result.distance if self.embedding_function is not None else None,
359
  )
360
- for result in results
361
  ]
362
  return docs
363
 
@@ -369,16 +373,31 @@ class CustomPGVector(VectorStore):
369
  ) -> List[Any]:
370
  """Query the collection."""
371
  with Session(self._conn) as session:
372
- results: List[Any] = (
373
- session.query(
374
- self.EmbeddingStore,
375
- self.distance_strategy(embedding).label("distance"), # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  )
377
- .order_by(sqlalchemy.asc("distance"))
378
- .limit(k)
379
- .all()
380
  )
381
- print(results)
 
382
  return results
383
 
384
  def similarity_search_by_vector(
 
1
  from __future__ import annotations
2
+ import pandas as pd
3
  import asyncio
4
  import contextlib
5
  import enum
 
28
  from pgvector.sqlalchemy import Vector
29
  from sqlalchemy import delete
30
  from sqlalchemy.orm import Session, declarative_base, relationship
31
+ from sqlalchemy import text
32
 
33
 
34
  class DistanceStrategy(str, enum.Enum):
 
349
  docs = [
350
  (
351
  Document(
352
+ page_content=result.abstract,
353
  metadata={
354
+ "id": result.id,
355
+ "title": result.title,
356
+ "authors": result.authors,
357
+ "doi": result.doi,
358
+ "keywords": results.keywords,
359
+ "distance": results.distance,
360
  },
361
  ),
362
  result.distance if self.embedding_function is not None else None,
363
  )
364
+ for result in results.itertuples()
365
  ]
366
  return docs
367
 
 
373
  ) -> List[Any]:
374
  """Query the collection."""
375
  with Session(self._conn) as session:
376
+ results = session.execute(
377
+ text(
378
+ f"""
379
+ select
380
+ a.id,
381
+ a.title,
382
+ a.doi,
383
+ a.abstract,
384
+ string_agg(distinct keyword."name", ',') as keywords,
385
+ string_agg(distinct author."name", ',') as authors,
386
+ abstract_embedding <-> '{str(embedding)}' as distance
387
+ from article a
388
+ left join article_keyword ON article_keyword.article_id = a.id
389
+ left join keyword on article_keyword.keyword_id = keyword.id
390
+ left join article_author ON article_author.article_id = a.id
391
+ left join author on author.id = article_author.author_id
392
+ where abstract != 'NaN'
393
+ GROUP BY a.id
394
+ ORDER BY distance
395
+ LIMIT {k};
396
+ """
397
  )
 
 
 
398
  )
399
+ results = results.fetchall()
400
+ results = pd.DataFrame(results, columns=["id", "title", "doi", "abstract", "keywords", "authors", "distance"])
401
  return results
402
 
403
  def similarity_search_by_vector(