Léo Bourrel commited on
Commit
8505f96
·
1 Parent(s): f419f72

feat: add v1 of distance limit

Browse files
Files changed (2) hide show
  1. models/distance.py +7 -0
  2. vector_store.py +6 -3
models/distance.py CHANGED
@@ -1,6 +1,13 @@
1
  import enum
2
 
3
 
 
 
 
 
 
 
 
4
  class DistanceStrategy(str, enum.Enum):
5
  """Enumerator of the Distance strategies."""
6
 
 
1
  import enum
2
 
3
 
4
+ distance_strategy_limit = {
5
+ "l2": 1.05,
6
+ "cosine": 0.6,
7
+ "inner": 1.0,
8
+ }
9
+
10
+
11
  class DistanceStrategy(str, enum.Enum):
12
  """Enumerator of the Distance strategies."""
13
 
vector_store.py CHANGED
@@ -14,7 +14,7 @@ from sqlalchemy import delete, text
14
  from sqlalchemy.orm import Session
15
 
16
  from model import Article
17
- from models.distance import DistanceStrategy
18
  from utils import str_to_list
19
 
20
  DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
@@ -252,6 +252,8 @@ class CustomVectorStore(VectorStore):
252
  k: int = 4,
253
  ) -> List[Any]:
254
  """Query the collection."""
 
 
255
  with Session(self._conn) as session:
256
  results = session.execute(
257
  text(
@@ -272,10 +274,11 @@ class CustomVectorStore(VectorStore):
272
  left join author on author.id = article_author.author_id
273
  where
274
  abstract_en != '' and
275
- abstract_en != 'None'
 
276
  GROUP BY a.id
277
  ORDER BY distance
278
- LIMIT {k};
279
  """
280
  )
281
  )
 
14
  from sqlalchemy.orm import Session
15
 
16
  from model import Article
17
+ from models.distance import DistanceStrategy, distance_strategy_limit
18
  from utils import str_to_list
19
 
20
  DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
 
252
  k: int = 4,
253
  ) -> List[Any]:
254
  """Query the collection."""
255
+
256
+ limit = distance_strategy_limit[self._distance_strategy]
257
  with Session(self._conn) as session:
258
  results = session.execute(
259
  text(
 
274
  left join author on author.id = article_author.author_id
275
  where
276
  abstract_en != '' and
277
+ abstract_en != 'None' and
278
+ abstract_embedding_en {self.distance_strategy} '{str(embedding)}' < {limit}
279
  GROUP BY a.id
280
  ORDER BY distance
281
+ LIMIT 100;
282
  """
283
  )
284
  )