leo-bourrel commited on
Commit
169e727
·
1 Parent(s): 4278112

feat: handle different distance strategy

Browse files
Files changed (1) hide show
  1. vector_store.py +6 -6
vector_store.py CHANGED
@@ -17,7 +17,7 @@ from model import Article
17
  from models.distance import DistanceStrategy
18
  from utils import str_to_list
19
 
20
- DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
21
 
22
  _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
23
 
@@ -200,13 +200,13 @@ class CustomVectorStore(VectorStore):
200
  return docs
201
 
202
  @property
203
- def distance_strategy(self) -> Any:
204
  if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
205
- return self.EmbeddingStore.embedding.l2_distance
206
  elif self._distance_strategy == DistanceStrategy.COSINE:
207
- return self.EmbeddingStore.embedding.cosine_distance
208
  elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
209
- return self.EmbeddingStore.embedding.max_inner_product
210
  else:
211
  raise ValueError(
212
  f"Got unexpected value for distance: {self._distance_strategy}. "
@@ -265,7 +265,7 @@ class CustomVectorStore(VectorStore):
265
  a.abstract_en,
266
  string_agg(distinct keyword."name", ',') as keywords,
267
  string_agg(distinct author."name", ',') as authors,
268
- abstract_embedding_en <-> '{str(embedding)}' as distance
269
  from article a
270
  left join article_keyword ON article_keyword.article_id = a.id
271
  left join keyword on article_keyword.keyword_id = keyword.id
 
17
  from models.distance import DistanceStrategy
18
  from utils import str_to_list
19
 
20
+ DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.EUCLIDEAN
21
 
22
  _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
23
 
 
200
  return docs
201
 
202
  @property
203
+ def distance_strategy(self) -> str | None:
204
  if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
205
+ return "<->"
206
  elif self._distance_strategy == DistanceStrategy.COSINE:
207
+ return "<=>"
208
  elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
209
+ return "<#>"
210
  else:
211
  raise ValueError(
212
  f"Got unexpected value for distance: {self._distance_strategy}. "
 
265
  a.abstract_en,
266
  string_agg(distinct keyword."name", ',') as keywords,
267
  string_agg(distinct author."name", ',') as authors,
268
+ abstract_embedding_en {self.distance_strategy} '{str(embedding)}' as distance
269
  from article a
270
  left join article_keyword ON article_keyword.article_id = a.id
271
  left join keyword on article_keyword.keyword_id = keyword.id