gen_data_qdrant.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import re
  2. import os
  3. import sys
  4. import math
  5. import threading
  6. import loguru
  7. from hashlib import md5
  8. os.chdir(sys.path[0][:-8])
  9. from langchain.text_splitter import CharacterTextSplitter
  10. from qdrant_client import QdrantClient
  11. from qdrant_client.http import models as rest
  12. from langchain.docstore.document import Document
  13. from typing import Dict, Iterable, List, Optional, Union
  14. from langchain.embeddings import HuggingFaceEmbeddings
  15. import logging, time
  16. logging.basicConfig()
  17. logger = logging.getLogger()
  18. logger.setLevel(logging.ERROR)
  19. import chardet
  20. import pdfplumber
  21. from qdrant import Qdrant
  22. from common import CounterLock
  23. from common import settings, error_print, error_helper, success_print
  24. source_folder = settings.librarys.qdrant.path
  25. source_folder_path = os.path.join(os.getcwd(), source_folder)
  26. root_path_list = source_folder_path.split(os.sep)
  27. docs = []
  28. texts_count = 0
  29. MetadataFilter = Dict[str, Union[str, int, bool]]
  30. COLLECTION_NAME = settings.librarys.qdrant.collection # 向量库名字
  31. model_path = settings.librarys.qdrant.model_path
  32. try:
  33. encode_kwargs = {'batch_size': settings.librarys.qdrant.batch_size}
  34. model_kwargs = {"device": settings.librarys.qdrant.device}
  35. embedding = HuggingFaceEmbeddings(model_name=model_path, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs)
  36. except Exception as e:
  37. error_helper("embedding加载失败,请下载相应模型",
  38. r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")
  39. raise e
  40. success_print("Embedding model加载完成")
  41. try:
  42. client = QdrantClient(path="memory/q")
  43. client.get_collection(COLLECTION_NAME)
  44. vectorstore = Qdrant(client, COLLECTION_NAME, embedding)
  45. except:
  46. del client
  47. vectorstore = None
  48. # vectorstore = None
  49. embedding_lock = CounterLock()
  50. vectorstore_lock = threading.Lock()
  51. def clac_embedding(texts, embedding, metadatas):
  52. global vectorstore
  53. with embedding_lock:
  54. embeddings = embedding.embed_documents(texts)
  55. with vectorstore_lock:
  56. ids = gen_ids(metadatas)
  57. if vectorstore is None:
  58. # 如需插入大规模数据可以将prefer_grpc参数置为True
  59. if(settings.librarys.qdrant.qdrant_path):
  60. vectorstore = Qdrant.from_texts(texts, embedding, embeddings, ids, metadatas=metadatas,
  61. path=settings.librarys.qdrant.qdrant_path, prefer_grpc=True,
  62. collection_name=settings.librarys.qdrant.collection, timeout=10)
  63. elif(settings.librarys.qdrant.qdrant_host):
  64. vectorstore = Qdrant.from_texts(texts, embedding, embeddings, ids, metadatas=metadatas,
  65. url=settings.librarys.qdrant.qdrant_host, prefer_grpc=True,
  66. collection_name=settings.librarys.qdrant.collection, timeout=10)
  67. else:
  68. vectorstore.add_texts(texts, embeddings, ids, metadatas)
  69. # 生成该id的方法仅供参考
  70. def gen_ids(metadatas):
  71. ids = []
  72. same_title_count = 0
  73. last_text_title = ""
  74. for metadata in metadatas:
  75. text_title = md5(metadata["source"].encode("utf-8")).hexdigest()
  76. if last_text_title != text_title:
  77. last_text_title = text_title
  78. same_title_count = 0
  79. else:
  80. same_title_count += 1
  81. origin = text_title[:30] + str(hex(same_title_count))[2:].zfill(3) # 最后三位为十六进制的文章段落数 前二十九位为文章title哈希
  82. origin = f"{origin[:8]}-{origin[8:12]}-{origin[12:16]}-{origin[16:20]}-{origin[-12:]}"
  83. ids.append(origin)
  84. return ids
  85. def make_index():
  86. global docs, texts_count
  87. if hasattr(settings.librarys.qdrant, "size") and hasattr(settings.librarys.qdrant, "overlap"):
  88. text_splitter = CharacterTextSplitter(
  89. chunk_size=int(settings.librarys.qdrant.size), chunk_overlap=int(settings.librarys.qdrant.overlap), separator='\n')
  90. else:
  91. text_splitter = CharacterTextSplitter(
  92. chunk_size=20, chunk_overlap=0, separator='\n')
  93. doc_texts = text_splitter.split_documents(docs)
  94. docs = []
  95. texts = [d.page_content for d in doc_texts]
  96. metadatas = [d.metadata for d in doc_texts]
  97. texts_count += len(texts)
  98. thread = threading.Thread(target=clac_embedding, args=(texts, embedding, metadatas))
  99. thread.start()
  100. while embedding_lock.get_waiting_threads() > 1:
  101. time.sleep(0.1)
  102. all_files = []
  103. for root, dirs, files in os.walk(source_folder_path):
  104. for file in files:
  105. all_files.append([root, file])
  106. success_print("文件列表生成完成", len(all_files))
  107. length_of_read = 0
  108. for i in range(len(all_files)):
  109. root, file = all_files[i]
  110. data = ""
  111. title = ""
  112. try:
  113. file_path = os.path.join(root, file)
  114. _, ext = os.path.splitext(file_path)
  115. if ext.lower() == '.pdf':
  116. # pdf
  117. with pdfplumber.open(file_path) as pdf:
  118. data_list = []
  119. for page in pdf.pages:
  120. print(page.extract_text())
  121. data_list.append(page.extract_text())
  122. data = "\n".join(data_list)
  123. elif ext.lower() == '.txt':
  124. # txt
  125. with open(file_path, 'rb') as f:
  126. print("open:",file_path)
  127. b = f.read()
  128. result = chardet.detect(b)
  129. with open(file_path, 'r', encoding=result['encoding']) as f:
  130. data = f.read()
  131. else:
  132. print("目前还不支持文件格式:", ext)
  133. except Exception as e:
  134. print("文件读取失败,当前文件已被跳过:", file, "。错误信息:", e)
  135. data = re.sub(r'!', "!\n", data)
  136. data = re.sub(r':', ":\n", data)
  137. data = re.sub(r'。', "。\n", data)
  138. data = re.sub(r'\r', "\n", data)
  139. data = re.sub(r'\n\n', "\n", data)
  140. data = re.sub(r"\n\s*\n", "\n", data)
  141. length_of_read += len(data)
  142. docs.append(Document(page_content=data, metadata={"source": file}))
  143. if length_of_read > 1e5: # 大于10万字的先处理(即不作为最后的统一处理)
  144. success_print("处理进度", int(100*i/len(all_files)), f"%\t({i}/{len(all_files)})")
  145. make_index()
  146. length_of_read = 0
  147. length_of_read += len(data)
  148. docs.append(Document(page_content=data, metadata={"source": file}))
  149. if length_of_read > 1e5:
  150. success_print("处理进度", int(100 * i / len(all_files)), f"%\t({i}/{len(all_files)})")
  151. make_index()
  152. length_of_read = 0
  153. if len(all_files) == 0:
  154. error_print("指定目录{}没有数据".format(settings.librarys.qdrant.path))
  155. sys.exit(0)
  156. if len(docs) > 0:
  157. make_index()
  158. while embedding_lock.get_waiting_threads() > 0:
  159. time.sleep(0.1)
  160. with embedding_lock:
  161. time.sleep(0.1)
  162. success_print("数据上装完成")
  163. with vectorstore_lock:
  164. print("开始构建索引,需要一定时间")
  165. vectorstore.client.update_collection(
  166. collection_name=COLLECTION_NAME,
  167. optimizer_config=rest.OptimizersConfigDiff(
  168. indexing_threshold=20000
  169. )
  170. )
  171. success_print("索引处理完成")
粤ICP备19079148号