# zhishiku_qdrant.py
  1. import re
  2. import time
  3. from sentence_transformers import SentenceTransformer
  4. from qdrant_client import QdrantClient
  5. from typing import Dict, List, Optional, Tuple, Union
  6. from plugins.common import settings, allowCROS
  7. from bottle import route, request, static_file
  8. MetadataFilter = Dict[str, Union[str, int, bool]]
  9. COLLECTION_NAME = settings.librarys.qdrant.collection
  10. divider = "\n"
  11. class QdrantIndex(object):
  12. def __init__(self, embedding_model):
  13. if(settings.librarys.qdrant.qdrant_path):
  14. self.qdrant_client = QdrantClient(
  15. path=settings.librarys.qdrant.qdrant_path,
  16. )
  17. elif(settings.librarys.qdrant.qdrant_host):
  18. self.qdrant_client = QdrantClient(
  19. url=settings.librarys.qdrant.qdrant_host,
  20. )
  21. self.embedding_model = embedding_model
  22. self.collection_name = COLLECTION_NAME
  23. def similarity_search_with_score(
  24. self, query, k=settings.librarys.qdrant.count
  25. ):
  26. embedding = self.embedding_model.encode(query)
  27. results = self.qdrant_client.search(
  28. collection_name=self.collection_name,
  29. query_vector=embedding,
  30. with_payload=True,
  31. limit=k,
  32. )
  33. return results
  34. def retrieve_from_id(self, _id):
  35. return self.qdrant_client.retrieve(self.collection_name, [_id])[0]
  36. def find(s, step=0):
  37. try:
  38. original_results = qdrant.similarity_search_with_score(s)
  39. docs = []
  40. for sample in original_results:
  41. if sample.score < settings.librarys.qdrant.similarity_threshold:
  42. continue
  43. docs.append(get_doc(sample, step))
  44. return docs
  45. except Exception as e:
  46. print(e)
  47. return []
  48. def get_doc(doc, step):
  49. final_content = doc.payload["page_content"]
  50. doc_source = doc.payload["metadata"]["source"]
  51. print("文段分数: ", doc.score, final_content)
  52. # 当前文段在对应文档中的分段数
  53. _id = int(doc.id[-3:], 16)
  54. if step > 0:
  55. for i in range(1, step+1):
  56. try:
  57. doc_before = qdrant.retrieve_from_id(doc.id[:-3] + str(hex(_id-i))[2:].zfill(3))
  58. # 可能出现哈希碰撞
  59. if doc_source == doc_before.payload["metadata"]["source"]:
  60. final_content = process_strings(doc_before.payload["page_content"], divider, final_content)
  61. except:
  62. pass
  63. try:
  64. doc_after = qdrant.retrieve_from_id(doc.id[:-3] + str(hex(_id+i))[2:].zfill(3))
  65. # 可能出现哈希碰撞
  66. if doc_source == doc_after.payload["metadata"]["source"]:
  67. final_content = process_strings(final_content, divider, doc_after.payload["page_content"])
  68. except:
  69. pass
  70. if doc_source.endswith(".pdf") or doc_source.endswith(".txt"):
  71. title = f"[{doc_source}](/{settings.librarys.qdrant.path}/{doc_source})"
  72. else:
  73. title = doc_source
  74. return {'title': title, 'content': re.sub(r'\n+', "\n", final_content), "score": doc.score}
  75. def process_strings(A, C, B):
  76. """
  77. find the longest common suffix of A and prefix of B
  78. """
  79. common = ""
  80. for i in range(1, min(len(A), len(B)) + 1):
  81. if A[-i:] == B[:i]:
  82. common = A[-i:]
  83. # if there is a common substring, replace one of them with C and concatenate
  84. if common:
  85. return A[:-len(common)] + C + B
  86. # otherwise, just return A + B
  87. else:
  88. return A + B
  89. embedding_model = SentenceTransformer(settings.librarys.qdrant.model_path, device=settings.librarys.qdrant.device)
  90. qdrant = QdrantIndex(embedding_model)
# 粤ICP备19079148号 (ICP registration footer of the hosting site; not part of the module)