11"""Wrapper around in-memory DocArray store."""
22from __future__ import annotations
33
4- from operator import itemgetter
54from typing import List , Optional , Any , Tuple , Iterable , Type , Callable , Sequence , TYPE_CHECKING
5+ from docarray .typing import NdArray
66
77from langchain .embeddings .base import Embeddings
8- from langchain .schema import Document
9- from langchain .vectorstores import VectorStore
108from langchain .vectorstores .base import VST
11- from langchain .vectorstores .utils import maximal_marginal_relevance
12-
13- from docarray import BaseDoc
14- from docarray .typing import NdArray
9+ from langchain .vectorstores .vector_store_from_doc_index import VecStoreFromDocIndex , _check_docarray_import
1510
1611
class HnswLib(VecStoreFromDocIndex):
    """Wrapper around HnswLib storage.

    To use it, you should have the ``docarray`` package with version >=0.31.0 installed.
    """

    def __init__(
        self,
        texts: List[str],
        embedding: Embeddings,
        work_dir: str,
        n_dim: int,
        metadatas: Optional[List[dict]] = None,
        dist_metric: str = 'cosine',
        **kwargs,
    ) -> None:
        """Initialize HnswLib store.

        Args:
            texts (List[str]): Text data.
            embedding (Embeddings): Embedding function.
            work_dir (str): Path to the location where all the data will be stored.
            n_dim (int): Dimension of an embedding.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            dist_metric (str): Distance metric for HnswLib; can be one of 'cosine',
                'ip', and 'l2'. Defaults to 'cosine'.
        """
        # Fail fast with a helpful message if docarray (>=0.31.0) is absent;
        # the index import below only works after that check passes.
        _check_docarray_import()
        from docarray.index import HnswDocumentIndex

        # HnswDocumentIndex needs protobuf for (de)serialization; surface a
        # clear install hint instead of a deep import error later.
        try:
            import google.protobuf  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import protobuf python package. "
                "Please install it with `pip install -U protobuf`."
            )

        # Build the per-store document schema, open (or create) the on-disk
        # index at work_dir, then let the shared base class embed and index.
        doc_cls = self._get_doc_cls(n_dim, dist_metric)
        doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
        super().__init__(doc_index, texts, embedding, metadatas)
7653
7754 @staticmethod
7855 def _get_doc_cls (n_dim : int , sim_metric : str ):
56+ from docarray import BaseDoc
7957 from pydantic import Field
8058
8159 class DocArrayDoc (BaseDoc ):
@@ -93,6 +71,7 @@ def from_texts(
9371 metadatas : Optional [List [dict ]] = None ,
9472 work_dir : str = None ,
9573 n_dim : int = None ,
74+ dist_metric : str = 'cosine' ,
9675 ** kwargs : Any
9776 ) -> HnswLib :
9877
@@ -107,129 +86,6 @@ def from_texts(
10786 texts = texts ,
10887 embedding = embedding ,
10988 metadatas = metadatas ,
110- kwargs = kwargs
89+ dist_metric = dist_metric ,
90+ kwargs = kwargs ,
11191 )
112-
113- def add_texts (
114- self ,
115- texts : Iterable [str ],
116- metadatas : Optional [List [dict ]] = None ,
117- ** kwargs : Any
118- ) -> List [str ]:
119- """Run more texts through the embeddings and add to the vectorstore.
120-
121- Args:
122- texts: Iterable of strings to add to the vectorstore.
123- metadatas: Optional list of metadatas associated with the texts.
124-
125- Returns:
126- List of ids from adding the texts into the vectorstore.
127- """
128- if metadatas is None :
129- metadatas = [{} for _ in range (len (list (texts )))]
130-
131- ids = []
132- embeddings = self .embedding .embed_documents (texts )
133- for t , m , e in zip (texts , metadatas , embeddings ):
134- doc = self .doc_cls (
135- text = t ,
136- embedding = e ,
137- metadata = m
138- )
139- self .doc_index .index (doc )
140- ids .append (doc .id ) # TODO return index of self.docs ?
141-
142- return ids
143-
144- def similarity_search_with_score (
145- self , query : str , k : int = 4 , ** kwargs : Any
146- ) -> List [Tuple [Document , float ]]:
147- """Return docs most similar to query.
148-
149- Args:
150- query: Text to look up documents similar to.
151- k: Number of Documents to return. Defaults to 4.
152-
153- Returns:
154- List of Documents most similar to the query and score for each.
155- """
156- query_embedding = self .embedding .embed_query (query )
157- query_embedding = [1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0. ]
158- print (f"query_embedding = { query_embedding } " )
159- query_doc = self .doc_cls (embedding = query_embedding )
160- docs , scores = self .doc_index .find (query_doc , search_field = 'embedding' , limit = k )
161-
162- result = [(Document (page_content = doc .text ), score ) for doc , score in zip (docs , scores )]
163- return result
164-
165- def similarity_search (
166- self , query : str , k : int = 4 , ** kwargs : Any
167- ) -> List [Document ]:
168- """Return docs most similar to query.
169-
170- Args:
171- query: Text to look up documents similar to.
172- k: Number of Documents to return. Defaults to 4.
173-
174- Returns:
175- List of Documents most similar to the query.
176- """
177- results = self .similarity_search_with_score (query , k )
178- return list (map (itemgetter (0 ), results ))
179-
180- def _similarity_search_with_relevance_scores (
181- self ,
182- query : str ,
183- k : int = 4 ,
184- ** kwargs : Any ,
185- ) -> List [Tuple [Document , float ]]:
186- """Return docs and relevance scores, normalized on a scale from 0 to 1.
187-
188- 0 is dissimilar, 1 is most similar.
189- """
190- raise NotImplementedError
191-
192- def similarity_search_by_vector (self , embedding : List [float ], k : int = 4 , ** kwargs : Any ) -> List [Document ]:
193- """Return docs most similar to embedding vector.
194-
195- Args:
196- embedding: Embedding to look up documents similar to.
197- k: Number of Documents to return. Defaults to 4.
198-
199- Returns:
200- List of Documents most similar to the query vector.
201- """
202-
203- query_doc = self .doc_cls (embedding = embedding )
204- docs = self .doc_index .find (query_doc , search_field = 'embedding' , limit = k ).documents
205-
206- result = [Document (page_content = doc .text ) for doc in docs ]
207- return result
208-
209- def max_marginal_relevance_search (
210- self , query : str , k : int = 4 , fetch_k : int = 20 , ** kwargs : Any
211- ) -> List [Document ]:
212- """Return docs selected using the maximal marginal relevance.
213-
214- Maximal marginal relevance optimizes for similarity to query AND diversity
215- among selected documents.
216-
217- Args:
218- query: Text to look up documents similar to.
219- k: Number of Documents to return. Defaults to 4.
220- fetch_k: Number of Documents to fetch to pass to MMR algorithm.
221-
222- Returns:
223- List of Documents selected by maximal marginal relevance.
224- """
225- query_embedding = self .embedding .embed_query (query )
226- query_doc = self .doc_cls (embedding = query_embedding )
227-
228- docs , scores = self .doc_index .find (query_doc , search_field = 'embedding' , limit = fetch_k )
229-
230- embeddings = [emb for emb in docs .emb ]
231-
232- mmr_selected = maximal_marginal_relevance (query_embedding , embeddings , k = k )
233- results = [Document (page_content = self .doc_index [idx ].text ) for idx in mmr_selected ]
234- return results
235-
0 commit comments