Léo Bourrel commited on
Commit
fc40941
·
1 Parent(s): 24d1b6f

feat: remove max marginal relevance search

Browse files
Files changed (1) hide show
  1. custom_pgvector.py +0 -217
custom_pgvector.py CHANGED
@@ -24,7 +24,6 @@ from langchain.schema.embeddings import Embeddings
24
  from langchain.utils import get_from_dict_or_env
25
  from langchain.vectorstores.base import VectorStore
26
  from langchain.vectorstores.pgvector import BaseModel
27
- from langchain.vectorstores.utils import maximal_marginal_relevance
28
  from pgvector.sqlalchemy import Vector
29
  from sqlalchemy import delete
30
  from sqlalchemy.orm import Session, declarative_base, relationship
@@ -110,7 +109,6 @@ class CustomPGVector(VectorStore):
110
  distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
111
  pre_delete_collection: bool = False,
112
  logger: Optional[logging.Logger] = None,
113
- relevance_score_fn: Optional[Callable[[float], float]] = None,
114
  ) -> None:
115
  self.connection_string = connection_string
116
  self.embedding_function = embedding_function
@@ -120,7 +118,6 @@ class CustomPGVector(VectorStore):
120
  self._distance_strategy = distance_strategy
121
  self.pre_delete_collection = pre_delete_collection
122
  self.logger = logger or logging.getLogger(__name__)
123
- self.override_relevance_score_fn = relevance_score_fn
124
  self.__post_init__()
125
 
126
  def __post_init__(
@@ -592,217 +589,3 @@ class CustomPGVector(VectorStore):
592
  ) -> str:
593
  """Return connection string from database parameters."""
594
  return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
595
-
596
- def _select_relevance_score_fn(self) -> Callable[[float], float]:
597
- """
598
- The 'correct' relevance function
599
- may differ depending on a few things, including:
600
- - the distance / similarity metric used by the VectorStore
601
- - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
602
- - embedding dimensionality
603
- - etc.
604
- """
605
- if self.override_relevance_score_fn is not None:
606
- return self.override_relevance_score_fn
607
-
608
- # Default strategy is to rely on distance strategy provided
609
- # in vectorstore constructor
610
- if self._distance_strategy == DistanceStrategy.COSINE:
611
- return self._cosine_relevance_score_fn
612
- elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
613
- return self._euclidean_relevance_score_fn
614
- elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
615
- return self._max_inner_product_relevance_score_fn
616
- else:
617
- raise ValueError(
618
- "No supported normalization function"
619
- f" for distance_strategy of {self._distance_strategy}."
620
- "Consider providing relevance_score_fn to PGVector constructor."
621
- )
622
-
623
- def max_marginal_relevance_search_with_score_by_vector(
624
- self,
625
- embedding: List[float],
626
- k: int = 4,
627
- fetch_k: int = 20,
628
- lambda_mult: float = 0.5,
629
- filter: Optional[Dict[str, str]] = None,
630
- **kwargs: Any,
631
- ) -> List[Tuple[Document, float]]:
632
- """Return docs selected using the maximal marginal relevance with score
633
- to embedding vector.
634
-
635
- Maximal marginal relevance optimizes for similarity to query AND diversity
636
- among selected documents.
637
-
638
- Args:
639
- embedding: Embedding to look up documents similar to.
640
- k (int): Number of Documents to return. Defaults to 4.
641
- fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
642
- Defaults to 20.
643
- lambda_mult (float): Number between 0 and 1 that determines the degree
644
- of diversity among the results with 0 corresponding
645
- to maximum diversity and 1 to minimum diversity.
646
- Defaults to 0.5.
647
- filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
648
-
649
- Returns:
650
- List[Tuple[Document, float]]: List of Documents selected by maximal marginal
651
- relevance to the query and score for each.
652
- """
653
- results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
654
-
655
- embedding_list = [result.EmbeddingStore.embedding for result in results]
656
-
657
- mmr_selected = maximal_marginal_relevance(
658
- np.array(embedding, dtype=np.float32),
659
- embedding_list,
660
- k=k,
661
- lambda_mult=lambda_mult,
662
- )
663
-
664
- candidates = self._results_to_docs_and_scores(results)
665
-
666
- return [r for i, r in enumerate(candidates) if i in mmr_selected]
667
-
668
- def max_marginal_relevance_search(
669
- self,
670
- query: str,
671
- k: int = 4,
672
- fetch_k: int = 20,
673
- lambda_mult: float = 0.5,
674
- filter: Optional[Dict[str, str]] = None,
675
- **kwargs: Any,
676
- ) -> List[Document]:
677
- """Return docs selected using the maximal marginal relevance.
678
-
679
- Maximal marginal relevance optimizes for similarity to query AND diversity
680
- among selected documents.
681
-
682
- Args:
683
- query (str): Text to look up documents similar to.
684
- k (int): Number of Documents to return. Defaults to 4.
685
- fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
686
- Defaults to 20.
687
- lambda_mult (float): Number between 0 and 1 that determines the degree
688
- of diversity among the results with 0 corresponding
689
- to maximum diversity and 1 to minimum diversity.
690
- Defaults to 0.5.
691
- filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
692
-
693
- Returns:
694
- List[Document]: List of Documents selected by maximal marginal relevance.
695
- """
696
- embedding = self.embedding_function.embed_query(query)
697
- return self.max_marginal_relevance_search_by_vector(
698
- embedding,
699
- k=k,
700
- fetch_k=fetch_k,
701
- lambda_mult=lambda_mult,
702
- **kwargs,
703
- )
704
-
705
- def max_marginal_relevance_search_with_score(
706
- self,
707
- query: str,
708
- k: int = 4,
709
- fetch_k: int = 20,
710
- lambda_mult: float = 0.5,
711
- filter: Optional[dict] = None,
712
- **kwargs: Any,
713
- ) -> List[Tuple[Document, float]]:
714
- """Return docs selected using the maximal marginal relevance with score.
715
-
716
- Maximal marginal relevance optimizes for similarity to query AND diversity
717
- among selected documents.
718
-
719
- Args:
720
- query (str): Text to look up documents similar to.
721
- k (int): Number of Documents to return. Defaults to 4.
722
- fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
723
- Defaults to 20.
724
- lambda_mult (float): Number between 0 and 1 that determines the degree
725
- of diversity among the results with 0 corresponding
726
- to maximum diversity and 1 to minimum diversity.
727
- Defaults to 0.5.
728
- filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
729
-
730
- Returns:
731
- List[Tuple[Document, float]]: List of Documents selected by maximal marginal
732
- relevance to the query and score for each.
733
- """
734
- embedding = self.embedding_function.embed_query(query)
735
- docs = self.max_marginal_relevance_search_with_score_by_vector(
736
- embedding=embedding,
737
- k=k,
738
- fetch_k=fetch_k,
739
- lambda_mult=lambda_mult,
740
- filter=filter,
741
- **kwargs,
742
- )
743
- return docs
744
-
745
- def max_marginal_relevance_search_by_vector(
746
- self,
747
- embedding: List[float],
748
- k: int = 4,
749
- fetch_k: int = 20,
750
- lambda_mult: float = 0.5,
751
- filter: Optional[Dict[str, str]] = None,
752
- **kwargs: Any,
753
- ) -> List[Document]:
754
- """Return docs selected using the maximal marginal relevance
755
- to embedding vector.
756
-
757
- Maximal marginal relevance optimizes for similarity to query AND diversity
758
- among selected documents.
759
-
760
- Args:
761
- embedding (str): Text to look up documents similar to.
762
- k (int): Number of Documents to return. Defaults to 4.
763
- fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
764
- Defaults to 20.
765
- lambda_mult (float): Number between 0 and 1 that determines the degree
766
- of diversity among the results with 0 corresponding
767
- to maximum diversity and 1 to minimum diversity.
768
- Defaults to 0.5.
769
- filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
770
-
771
- Returns:
772
- List[Document]: List of Documents selected by maximal marginal relevance.
773
- """
774
- docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
775
- embedding,
776
- k=k,
777
- fetch_k=fetch_k,
778
- lambda_mult=lambda_mult,
779
- filter=filter,
780
- **kwargs,
781
- )
782
-
783
- return _results_to_docs(docs_and_scores)
784
-
785
- async def amax_marginal_relevance_search_by_vector(
786
- self,
787
- embedding: List[float],
788
- k: int = 4,
789
- fetch_k: int = 20,
790
- lambda_mult: float = 0.5,
791
- filter: Optional[Dict[str, str]] = None,
792
- **kwargs: Any,
793
- ) -> List[Document]:
794
- """Return docs selected using the maximal marginal relevance."""
795
-
796
- # This is a temporary workaround to make the similarity search
797
- # asynchronous. The proper solution is to make the similarity search
798
- # asynchronous in the vector store implementations.
799
- func = partial(
800
- self.max_marginal_relevance_search_by_vector,
801
- embedding,
802
- k=k,
803
- fetch_k=fetch_k,
804
- lambda_mult=lambda_mult,
805
- filter=filter,
806
- **kwargs,
807
- )
808
- return await asyncio.get_event_loop().run_in_executor(None, func)
 
24
  from langchain.utils import get_from_dict_or_env
25
  from langchain.vectorstores.base import VectorStore
26
  from langchain.vectorstores.pgvector import BaseModel
 
27
  from pgvector.sqlalchemy import Vector
28
  from sqlalchemy import delete
29
  from sqlalchemy.orm import Session, declarative_base, relationship
 
109
  distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
110
  pre_delete_collection: bool = False,
111
  logger: Optional[logging.Logger] = None,
 
112
  ) -> None:
113
  self.connection_string = connection_string
114
  self.embedding_function = embedding_function
 
118
  self._distance_strategy = distance_strategy
119
  self.pre_delete_collection = pre_delete_collection
120
  self.logger = logger or logging.getLogger(__name__)
 
121
  self.__post_init__()
122
 
123
  def __post_init__(
 
589
  ) -> str:
590
  """Return connection string from database parameters."""
591
  return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"