qdrant.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -*- coding: utf-8 -*-
  2. from __future__ import annotations
  3. import uuid
  4. import warnings
  5. from hashlib import md5
  6. from operator import itemgetter
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Callable,
  11. Dict,
  12. Iterable,
  13. List,
  14. Optional,
  15. Tuple,
  16. Type,
  17. Union,
  18. )
  19. import numpy as np
  20. from langchain.docstore.document import Document
  21. from langchain.embeddings.base import Embeddings
  22. from langchain.vectorstores import VectorStore
  23. from langchain.vectorstores.utils import maximal_marginal_relevance
  24. if TYPE_CHECKING:
  25. from qdrant_client.http import models as rest
# Metadata filter accepted by the Qdrant search methods. Each key names a field
# inside the stored metadata payload. Each value is one of:
#   - a scalar, matched exactly;
#   - a nested dict, whose keys are joined to the parent key with ".";
#   - a list, where every element becomes its own condition (dict elements use
#     the "key[]" array syntax).
# All resulting conditions are combined with `must`, i.e. they must ALL hold.
MetadataFilter = Dict[str, Union[str, int, bool, dict, list]]
  27. class Qdrant(object):
  28. """Wrapper around Qdrant vector database.
  29. To use you should have the ``qdrant-client`` package installed.
  30. Example:
  31. .. code-block:: python
  32. from qdrant_client import QdrantClient
  33. from langchain import Qdrant
  34. client = QdrantClient()
  35. collection_name = "MyCollection"
  36. qdrant = Qdrant(client, collection_name, embedding_function)
  37. """
  38. CONTENT_KEY = "page_content"
  39. METADATA_KEY = "metadata"
  40. def __init__(
  41. self,
  42. client: Any,
  43. collection_name: str,
  44. embeddings: Optional[Embeddings] = None,
  45. content_payload_key: str = CONTENT_KEY,
  46. metadata_payload_key: str = METADATA_KEY,
  47. embedding_function: Optional[Callable] = None, # deprecated
  48. ):
  49. """Initialize with necessary components."""
  50. try:
  51. import qdrant_client
  52. except ImportError:
  53. raise ValueError(
  54. "Could not import qdrant-client python package. "
  55. "Please install it with `pip install qdrant-client`."
  56. )
  57. if not isinstance(client, qdrant_client.QdrantClient):
  58. raise ValueError(
  59. f"client should be an instance of qdrant_client.QdrantClient, "
  60. f"got {type(client)}"
  61. )
  62. if embeddings is None and embedding_function is None:
  63. raise ValueError(
  64. "`embeddings` value can't be None. Pass `Embeddings` instance."
  65. )
  66. if embeddings is not None and embedding_function is not None:
  67. raise ValueError(
  68. "Both `embeddings` and `embedding_function` are passed. "
  69. "Use `embeddings` only."
  70. )
  71. self.embeddings = embeddings
  72. self._embeddings_function = embedding_function
  73. self.client: qdrant_client.QdrantClient = client
  74. self.collection_name = collection_name
  75. self.content_payload_key = content_payload_key or self.CONTENT_KEY
  76. self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
  77. if embedding_function is not None:
  78. warnings.warn(
  79. "Using `embedding_function` is deprecated. "
  80. "Pass `Embeddings` instance to `embeddings` instead."
  81. )
  82. if not isinstance(embeddings, Embeddings):
  83. warnings.warn(
  84. "`embeddings` should be an instance of `Embeddings`."
  85. "Using `embeddings` as `embedding_function` which is deprecated"
  86. )
  87. self._embeddings_function = embeddings
  88. self.embeddings = None
  89. def _embed_query(self, query: str) -> List[float]:
  90. """Embed query text.
  91. Used to provide backward compatibility with `embedding_function` argument.
  92. Args:
  93. query: Query text.
  94. Returns:
  95. List of floats representing the query embedding.
  96. """
  97. if self.embeddings is not None:
  98. embedding = self.embeddings.embed_query(query)
  99. else:
  100. if self._embeddings_function is not None:
  101. embedding = self._embeddings_function(query)
  102. else:
  103. raise ValueError("Neither of embeddings or embedding_function is set")
  104. return embedding.tolist() if hasattr(embedding, "tolist") else embedding
  105. def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
  106. """Embed search texts.
  107. Used to provide backward compatibility with `embedding_function` argument.
  108. Args:
  109. texts: Iterable of texts to embed.
  110. Returns:
  111. List of floats representing the texts embedding.
  112. """
  113. if self.embeddings is not None:
  114. embeddings = self.embeddings.embed_documents(list(texts))
  115. if hasattr(embeddings, "tolist"):
  116. embeddings = embeddings.tolist()
  117. elif self._embeddings_function is not None:
  118. embeddings = []
  119. for text in texts:
  120. embedding = self._embeddings_function(text)
  121. if hasattr(embeddings, "tolist"):
  122. embedding = embedding.tolist()
  123. embeddings.append(embedding)
  124. else:
  125. raise ValueError("Neither of embeddings or embedding_function is set")
  126. return embeddings
  127. def add_texts(
  128. self,
  129. texts: Iterable[str],
  130. embeddings,
  131. ids=None,
  132. metadatas: Optional[List[dict]] = None,
  133. **kwargs: Any,
  134. ) -> List[str]:
  135. """Run more texts through the embeddings and add to the vectorstore.
  136. Args:
  137. texts: Iterable of strings to add to the vectorstore.
  138. embeddings: the embeddings of texts
  139. ids: Optional list of ids to associate with the texts. Ids have to be
  140. uuid-like strings.
  141. metadatas: Optional list of metadatas associated with the texts.
  142. Returns:
  143. List of ids from adding the texts into the vectorstore.
  144. """
  145. texts = list(
  146. texts
  147. ) # otherwise iterable might be exhausted after id calculation
  148. if not ids:
  149. ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
  150. self.client.upload_collection(
  151. collection_name=self.collection_name,
  152. vectors=embeddings,
  153. payload=self._build_payloads(
  154. texts, metadatas, self.content_payload_key, self.metadata_payload_key
  155. ),
  156. ids=ids,
  157. parallel=1
  158. )
  159. return ids
  160. def similarity_search(
  161. self,
  162. query: str,
  163. k: int = 4,
  164. filter: Optional[MetadataFilter] = None,
  165. **kwargs: Any,
  166. ) -> List[Document]:
  167. """Return docs most similar to query.
  168. Args:
  169. query: Text to look up documents similar to.
  170. k: Number of Documents to return. Defaults to 4.
  171. filter: Filter by metadata. Defaults to None.
  172. Returns:
  173. List of Documents most similar to the query.
  174. """
  175. results = self.similarity_search_with_score(query, k, filter)
  176. return list(map(itemgetter(0), results))
  177. def similarity_search_with_score(
  178. self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None
  179. ) -> List[Tuple[Document, float]]:
  180. """Return docs most similar to query.
  181. Args:
  182. query: Text to look up documents similar to.
  183. k: Number of Documents to return. Defaults to 4.
  184. filter: Filter by metadata. Defaults to None.
  185. Returns:
  186. List of Documents most similar to the query and score for each.
  187. """
  188. results = self.client.search(
  189. collection_name=self.collection_name,
  190. query_vector=self._embed_query(query),
  191. query_filter=self._qdrant_filter_from_dict(filter),
  192. with_payload=True,
  193. limit=k,
  194. )
  195. return [
  196. (
  197. self._document_from_scored_point(
  198. result, self.content_payload_key, self.metadata_payload_key
  199. ),
  200. result.score,
  201. )
  202. for result in results
  203. ]
  204. def max_marginal_relevance_search(
  205. self,
  206. query: str,
  207. k: int = 4,
  208. fetch_k: int = 20,
  209. lambda_mult: float = 0.5,
  210. **kwargs: Any,
  211. ) -> List[Document]:
  212. """Return docs selected using the maximal marginal relevance.
  213. Maximal marginal relevance optimizes for similarity to query AND diversity
  214. among selected documents.
  215. Args:
  216. query: Text to look up documents similar to.
  217. k: Number of Documents to return. Defaults to 4.
  218. fetch_k: Number of Documents to fetch to pass to MMR algorithm.
  219. Defaults to 20.
  220. lambda_mult: Number between 0 and 1 that determines the degree
  221. of diversity among the results with 0 corresponding
  222. to maximum diversity and 1 to minimum diversity.
  223. Defaults to 0.5.
  224. Returns:
  225. List of Documents selected by maximal marginal relevance.
  226. """
  227. embedding = self._embed_query(query)
  228. results = self.client.search(
  229. collection_name=self.collection_name,
  230. query_vector=embedding,
  231. with_payload=True,
  232. with_vectors=True,
  233. limit=fetch_k,
  234. )
  235. embeddings = [result.vector for result in results]
  236. mmr_selected = maximal_marginal_relevance(
  237. np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
  238. )
  239. return [
  240. self._document_from_scored_point(
  241. results[i], self.content_payload_key, self.metadata_payload_key
  242. )
  243. for i in mmr_selected
  244. ]
  245. @classmethod
  246. def from_texts(
  247. cls: Type[Qdrant],
  248. texts: List[str],
  249. embedding: Embeddings,
  250. embeddings,
  251. ids=None,
  252. metadatas: Optional[List[dict]] = None,
  253. location: Optional[str] = None,
  254. url: Optional[str] = None,
  255. port: Optional[int] = 6333,
  256. grpc_port: int = 6334,
  257. prefer_grpc: bool = False,
  258. https: Optional[bool] = None,
  259. api_key: Optional[str] = None,
  260. prefix: Optional[str] = None,
  261. timeout: Optional[float] = None,
  262. host: Optional[str] = None,
  263. path: Optional[str] = None,
  264. collection_name: Optional[str] = None,
  265. distance_func: str = "Cosine",
  266. content_payload_key: str = CONTENT_KEY,
  267. metadata_payload_key: str = METADATA_KEY,
  268. **kwargs: Any,
  269. ) -> Qdrant:
  270. """Construct Qdrant wrapper from a list of texts.
  271. Args:
  272. texts: A list of texts to be indexed in Qdrant.
  273. embedding: A subclass of `Embeddings`, responsible for text vectorization.
  274. embeddings: the embeddings of the texts.
  275. ids:
  276. Optional list of ids to associate with the texts. Ids have to be
  277. uuid-like strings.
  278. metadatas:
  279. An optional list of metadata. If provided it has to be of the same
  280. length as a list of texts.
  281. location:
  282. If `:memory:` - use in-memory Qdrant instance.
  283. If `str` - use it as a `url` parameter.
  284. If `None` - fallback to relying on `host` and `port` parameters.
  285. url: either host or str of "Optional[scheme], host, Optional[port],
  286. Optional[prefix]". Default: `None`
  287. port: Port of the REST API interface. Default: 6333
  288. grpc_port: Port of the gRPC interface. Default: 6334
  289. prefer_grpc:
  290. If true - use gPRC interface whenever possible in custom methods.
  291. Default: False
  292. https: If true - use HTTPS(SSL) protocol. Default: None
  293. api_key: API key for authentication in Qdrant Cloud. Default: None
  294. prefix:
  295. If not None - add prefix to the REST URL path.
  296. Example: service/v1 will result in
  297. http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
  298. Default: None
  299. timeout:
  300. Timeout for REST and gRPC API requests.
  301. Default: 5.0 seconds for REST and unlimited for gRPC
  302. host:
  303. Host name of Qdrant service. If url and host are None, set to
  304. 'localhost'. Default: None
  305. path:
  306. Path in which the vectors will be stored while using local mode.
  307. Default: None
  308. collection_name:
  309. Name of the Qdrant collection to be used. If not provided,
  310. it will be created randomly. Default: None
  311. distance_func:
  312. Distance function. One of: "Cosine" / "Euclid" / "Dot".
  313. Default: "Cosine"
  314. content_payload_key:
  315. A payload key used to store the content of the document.
  316. Default: "page_content"
  317. metadata_payload_key:
  318. A payload key used to store the metadata of the document.
  319. Default: "metadata"
  320. **kwargs:
  321. Additional arguments passed directly into REST client initialization
  322. This is a user friendly interface that:
  323. 1. Creates embeddings, one for each text
  324. 2. Initializes the Qdrant database as an in-memory docstore by default
  325. (and overridable to a remote docstore)
  326. 3. Adds the text embeddings to the Qdrant database
  327. This is intended to be a quick way to get started.
  328. Example:
  329. .. code-block:: python
  330. from langchain import Qdrant
  331. from langchain.embeddings import OpenAIEmbeddings
  332. embeddings = OpenAIEmbeddings()
  333. qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
  334. """
  335. try:
  336. import qdrant_client
  337. except ImportError:
  338. raise ValueError(
  339. "Could not import qdrant-client python package. "
  340. "Please install it with `pip install qdrant-client`."
  341. )
  342. from qdrant_client.http import models as rest
  343. vector_size = embedding.client.get_sentence_embedding_dimension()
  344. collection_name = collection_name or uuid.uuid4().hex
  345. distance_func = distance_func.upper()
  346. client = qdrant_client.QdrantClient(
  347. location=location,
  348. url=url,
  349. port=port,
  350. grpc_port=grpc_port,
  351. prefer_grpc=prefer_grpc,
  352. https=https,
  353. api_key=api_key,
  354. prefix=prefix,
  355. timeout=timeout,
  356. host=host,
  357. path=path,
  358. **kwargs,
  359. )
  360. #
  361. client.recreate_collection(
  362. collection_name=collection_name,
  363. vectors_config=rest.VectorParams(
  364. size=vector_size,
  365. distance=rest.Distance[distance_func],
  366. on_disk=True
  367. ),
  368. optimizers_config=rest.OptimizersConfigDiff(
  369. indexing_threshold=0, default_segment_number=8,
  370. memmap_threshold=20000
  371. ),
  372. hnsw_config=rest.HnswConfigDiff(on_disk=True),
  373. shard_number=2,
  374. on_disk_payload=True
  375. )
  376. if not ids:
  377. ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]
  378. client.upload_collection(
  379. collection_name=collection_name,
  380. vectors=embeddings,
  381. payload=cls._build_payloads(
  382. texts, metadatas, content_payload_key, metadata_payload_key
  383. ),
  384. ids=ids,
  385. parallel=1
  386. )
  387. return cls(
  388. client=client,
  389. collection_name=collection_name,
  390. embeddings=embedding,
  391. content_payload_key=content_payload_key,
  392. metadata_payload_key=metadata_payload_key,
  393. )
  394. @classmethod
  395. def _build_payloads(
  396. cls,
  397. texts: Iterable[str],
  398. metadatas: Optional[List[dict]],
  399. content_payload_key: str,
  400. metadata_payload_key: str,
  401. ) -> List[dict]:
  402. payloads = []
  403. for i, text in enumerate(texts):
  404. if text is None:
  405. raise ValueError(
  406. "At least one of the texts is None. Please remove it before "
  407. "calling .from_texts or .add_texts on Qdrant instance."
  408. )
  409. metadata = metadatas[i] if metadatas is not None else None
  410. payloads.append(
  411. {
  412. content_payload_key: text,
  413. metadata_payload_key: metadata,
  414. }
  415. )
  416. return payloads
  417. @classmethod
  418. def _document_from_scored_point(
  419. cls,
  420. scored_point: Any,
  421. content_payload_key: str,
  422. metadata_payload_key: str,
  423. ) -> Document:
  424. return Document(
  425. page_content=scored_point.payload.get(content_payload_key),
  426. metadata=scored_point.payload.get(metadata_payload_key) or {},
  427. )
  428. def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
  429. from qdrant_client.http import models as rest
  430. out = []
  431. if isinstance(value, dict):
  432. for _key, value in value.items():
  433. out.extend(self._build_condition(f"{key}.{_key}", value))
  434. elif isinstance(value, list):
  435. for _value in value:
  436. if isinstance(_value, dict):
  437. out.extend(self._build_condition(f"{key}[]", _value))
  438. else:
  439. out.extend(self._build_condition(f"{key}", _value))
  440. else:
  441. out.append(
  442. rest.FieldCondition(
  443. key=f"{self.metadata_payload_key}.{key}",
  444. match=rest.MatchValue(value=value),
  445. )
  446. )
  447. return out
  448. def _qdrant_filter_from_dict(
  449. self, filter: Optional[MetadataFilter]
  450. ) -> Optional[rest.Filter]:
  451. from qdrant_client.http import models as rest
  452. if not filter:
  453. return None
  454. return rest.Filter(
  455. must=[
  456. condition
  457. for key, value in filter.items()
  458. for condition in self._build_condition(key, value)
  459. ]
  460. )
粤ICP备19079148号