gen_data_st.py

import sentence_transformers
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
import threading
import pdfplumber
import re
import chardet
import os
import sys
import time
import docx

sys.path.append(os.getcwd())
from plugins.common import success_print, error_print
from plugins.common import error_helper
from plugins.common import settings
from plugins.common import CounterLock

# Select the vector store backend from the config.
if settings.librarys.rtst.backend == "Annoy":
    from langchain.vectorstores.annoy import Annoy as Vectorstore
else:
    from langchain.vectorstores.faiss import FAISS as Vectorstore
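
# The `settings.librarys.rtst.*` keys read by this script come from wenda's
# config file. A minimal sketch of the relevant section, assuming the YAML
# layout used by the project (the values shown are illustrative):
#
#   librarys:
#     rtst:
#       backend: Annoy                           # anything else selects FAISS
#       model_path: model/text2vec-large-chinese
#       size: 400                                # splitter chunk_size (optional)
#       overlap: 50                              # splitter chunk_overlap (optional)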
source_folder = 'txt'
source_folder_path = os.path.join(os.getcwd(), source_folder)

import logging
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.ERROR)

root_path_list = source_folder_path.split(os.sep)
docs = []
vectorstore = None

model_path = settings.librarys.rtst.model_path
try:
    if model_path.startswith("http"):  # e.g. "http://127.0.0.1:3000/": an OpenAI-compatible endpoint
        from langchain.embeddings.openai import OpenAIEmbeddings
        os.environ["OPENAI_API_TYPE"] = "open_ai"
        os.environ["OPENAI_API_BASE"] = model_path
        os.environ["OPENAI_API_KEY"] = "your OpenAI key"
        embeddings = OpenAIEmbeddings(
            deployment="text-embedding-ada-002",
            model="text-embedding-ada-002"
        )
    else:
        # Load a local sentence-transformers model onto the GPU.
        from langchain.embeddings import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(model_name='')
        embeddings.client = sentence_transformers.SentenceTransformer(
            model_path, device="cuda")
except Exception as e:
    error_helper("Failed to load the embedding model",
                 r"https://github.com/l15y/wenda")
    raise e
success_print("Embedding model loaded")
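
# Optional sanity check (illustrative only, not part of the original script);
# `embed_query` is the standard langchain Embeddings method:
#
#   vec = embeddings.embed_query("你好")
#   print(len(vec))  # embedding dimension, e.g. 1024 for text2vec-large-chinese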
embedding_lock = CounterLock()
vectorstore_lock = threading.Lock()
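
# `CounterLock` is provided by plugins.common. A minimal sketch of the
# interface this script relies on: a mutex that also counts how many threads
# are blocked waiting on it (an illustrative assumption, not the project's
# actual implementation):
#
#   class CounterLock:
#       def __init__(self):
#           self._lock = threading.Lock()
#           self._meta = threading.Lock()   # guards the waiting counter
#           self._waiting = 0
#
#       def __enter__(self):
#           with self._meta:
#               self._waiting += 1
#           self._lock.acquire()
#           with self._meta:
#               self._waiting -= 1
#           return self
#
#       def __exit__(self, *exc):
#           self._lock.release()
#           return False
#
#       def get_waiting_threads(self):
#           return self._waiting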

def calc_embedding(texts, embeddings, metadatas):
    """Build a vector store from the given chunks and merge it into the global one."""
    global vectorstore
    with embedding_lock:
        vectorstore_new = Vectorstore.from_texts(texts, embeddings, metadatas=metadatas)
    with vectorstore_lock:
        if vectorstore is None:
            vectorstore = vectorstore_new
        else:
            vectorstore.merge_from(vectorstore_new)

def make_index():
    global docs
    if hasattr(settings.librarys.rtst, "size") and hasattr(settings.librarys.rtst, "overlap"):
        text_splitter = CharacterTextSplitter(
            chunk_size=int(settings.librarys.rtst.size),
            chunk_overlap=int(settings.librarys.rtst.overlap),
            separator='\n')
    else:
        text_splitter = CharacterTextSplitter(
            chunk_size=20, chunk_overlap=0, separator='\n')
    doc_texts = text_splitter.split_documents(docs)
    docs = []
    texts = [d.page_content for d in doc_texts]
    metadatas = [d.metadata for d in doc_texts]
    # Embed in a background thread; throttle if too many threads are queued.
    thread = threading.Thread(target=calc_embedding, args=(texts, embeddings, metadatas))
    thread.start()
    while embedding_lock.get_waiting_threads() > 2:
        time.sleep(0.1)
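
# How the splitter behaves (a minimal sketch, not part of the original
# script): CharacterTextSplitter splits each document on '\n' and joins the
# pieces back into chunks of roughly chunk_size characters, overlapping by
# chunk_overlap; each chunk keeps the source document's metadata:
#
#   splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator='\n')
#   chunks = splitter.split_documents(
#       [Document(page_content="第一行\n第二行\n第三行", metadata={"source": "demo.txt"})])
#   for c in chunks:
#       print(c.page_content, c.metadata)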

all_files = []
for root, dirs, files in os.walk(source_folder_path):
    for file in files:
        all_files.append([root, file])
success_print("File list generated", len(all_files))
length_of_read = 0
for i in range(len(all_files)):
    root, file = all_files[i]
    data = ""
    try:
        file_path = os.path.join(root, file)
        _, ext = os.path.splitext(file_path)
        if ext.lower() == '.pdf':
            # PDF: extract the text of every page.
            with pdfplumber.open(file_path) as pdf:
                data_list = []
                for page in pdf.pages:
                    # extract_text() can return None for image-only pages
                    data_list.append(page.extract_text() or "")
                data = "\n".join(data_list)
        elif ext.lower() == '.txt':
            # Plain text: sniff the encoding first, then decode.
            with open(file_path, 'rb') as f:
                b = f.read()
                result = chardet.detect(b)
            with open(file_path, 'r', encoding=result['encoding']) as f:
                data = f.read()
        elif ext.lower() == '.docx':
            doc = docx.Document(file_path)
            data_list = []
            for para in doc.paragraphs:
                data_list.append(para.text)
            data = '\n'.join(data_list)
        else:
            print("Unsupported file format:", ext)
            continue
    except Exception as e:
        print("Failed to read file, skipping it:", file, "Error:", e)
        continue
    # data = re.sub(r'!', "!\n", data)
    # data = re.sub(r':', ":\n", data)
    # data = re.sub(r'。', "。\n", data)
    data = re.sub(r"\n\s*\n", "\n", data)
    data = re.sub(r'\r', "\n", data)
    data = re.sub(r'\n\n', "\n", data)
    length_of_read += len(data)
    docs.append(Document(page_content=data, metadata={"source": file}))
    # Flush to the index every ~100k characters read, to bound memory use.
    if length_of_read > 1e5:
        success_print("Progress", int(100 * i / len(all_files)), f"%\t({i}/{len(all_files)})")
        make_index()
        # print(embedding_lock.get_waiting_threads())
        length_of_read = 0

if len(all_files) == 0:
    error_print("No data found in the txt directory")
    sys.exit(0)
if len(docs) > 0:
    make_index()
# Wait for all queued embedding threads to drain.
while embedding_lock.get_waiting_threads() > 0:
    time.sleep(0.1)
success_print("Progress", 100, "%")
# Taking both locks ensures the last worker has finished merging.
with embedding_lock:
    time.sleep(0.1)
    with vectorstore_lock:
        success_print("Processing finished")
try:
    # Merge into an existing on-disk index if one is present.
    vectorstore_old = Vectorstore.load_local(
        'memory/default', embeddings=embeddings)
    success_print("Merged into the existing index. Delete the memory/default folder if you do not want to merge.")
    vectorstore_old.merge_from(vectorstore)
    vectorstore_old.save_local('memory/default')
except Exception:
    print("Creating a new index")
    vectorstore.save_local('memory/default')
success_print("Save complete")
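
# Using the result (a minimal sketch, not part of the original script):
# reload the saved index with the same embeddings object and run a
# similarity search over it.
#
#   store = Vectorstore.load_local('memory/default', embeddings=embeddings)
#   for doc in store.similarity_search("查询内容", k=3):
#       print(doc.metadata["source"], doc.page_content[:80])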