zhishiku_rtst.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from langchain.embeddings import HuggingFaceEmbeddings
  2. import sentence_transformers
  3. import numpy as np
  4. import re,os
  5. from plugins.common import settings,allowCROS
  6. from plugins.common import error_helper
  7. from plugins.common import success_print
  # Choose the vector-store implementation named by the RTST settings:
  # "Annoy" selects the Annoy wrapper, any other value selects FAISS.
  if settings.librarys.rtst.backend == "Annoy":
      from langchain.vectorstores.annoy import Annoy as Vectorstore
  else:
      from langchain.vectorstores.faiss import FAISS as Vectorstore

  # Separator placed between neighbouring chunks when get_doc stitches
  # them together.
  divider = '\n'

  # Directory that holds the saved memories. exist_ok=True closes the window
  # between the existence check and the creation (another process may create
  # the directory in between); the check itself is kept so an existing path
  # is still left alone exactly as before.
  if not os.path.exists('memory'):
      os.makedirs('memory', exist_ok=True)

  # Shorthand for the RTST section of the settings. The misspelled name is
  # kept because other code in this file refers to it.
  cunnrent_setting = settings.librarys.rtst
  def get_doc_by_id(id,memory_name):
      """Return the document stored at index position ``id`` of a loaded memory.

      The position is translated through the store's ``index_to_docstore_id``
      mapping and the resulting key is handed to ``docstore.search``; whatever
      that call returns is passed back unchanged.

      Raises ``KeyError`` when ``memory_name`` is not present in
      ``vectorstores``; lookup errors from the mapping propagate as well.
      """
      store = vectorstores[memory_name]
      search = store.docstore.search
      return search(store.index_to_docstore_id[id])
  def process_strings(A, C, B):
      """Join ``A`` and ``B``, merging an overlap between A's tail and B's head.

      When the end of ``A`` repeats the start of ``B``, the longest such
      overlap is removed from ``A`` and ``C`` is placed between the trimmed
      ``A`` and the whole of ``B``. When there is no overlap the result is
      simply ``A + B`` (``C`` is not inserted in that case).
      """
      # Try the widest candidate first so the first hit is the longest overlap.
      for width in range(min(len(A), len(B)), 0, -1):
          if A[-width:] == B[:width]:
              return A[:-width] + C + B
      return A + B
  def get_title_by_doc(doc):
      """Return the document's ``source`` metadata without any 【…】 span.

      The pattern is greedy, so everything from the first ``【`` to the last
      ``】`` on a line is removed in one go.
      """
      source = doc.metadata['source']
      return re.sub('【.+】', '', source)
  def get_doc(id,score,step,memory_name):
      """Build one search-result entry for the chunk at index position ``id``.

      The chunk's text is extended with up to ``step`` neighbouring chunks on
      each side, but only with neighbours whose title (see
      ``get_title_by_doc``) equals the title of the hit itself. Overlapping
      text between neighbours is merged by ``process_strings``.

      Args:
          id: position of the hit in the store's index.
          score: score reported by the index for this hit; returned as ``int``.
          step: how many neighbours to try on each side; values <= 0 add none.
          memory_name: key of the loaded memory in ``vectorstores``.

      Returns:
          A dict with ``title`` (a markdown link to ``/txt/<source>`` for
          ``.pdf``/``.txt`` sources, otherwise the raw source), ``content``
          (runs of newlines collapsed to one) and ``score``.
      """
      doc = get_doc_by_id(id,memory_name)
      source = doc.metadata['source']
      doc_title = get_title_by_doc(doc)
      final_content = doc.page_content

      # range() is empty for step <= 0, so no separate guard is needed.
      for i in range(1, step+1):
          # Neighbour lookup is best effort: a missing or malformed neighbour
          # is skipped. Only Exception is caught so that KeyboardInterrupt and
          # SystemExit are no longer swallowed.
          if id - i >= 0:
              # A negative position can never be a real predecessor; skipping
              # it explicitly avoids wrapping around to the end should the
              # position mapping ever be a sequence rather than a dict.
              try:
                  doc_before = get_doc_by_id(id-i,memory_name)
                  if get_title_by_doc(doc_before) == doc_title:
                      final_content = process_strings(doc_before.page_content,divider,final_content)
              except Exception:
                  pass
          try:
              doc_after = get_doc_by_id(id+i,memory_name)
              if get_title_by_doc(doc_after) == doc_title:
                  final_content = process_strings(final_content,divider,doc_after.page_content)
          except Exception:
              pass

      if source.endswith((".pdf", ".txt")):
          title = f"[{source}](/txt/{source})"
      else:
          title = source
      return {'title': title,'content':re.sub(r'\n+', "\n", final_content),"score":int(score)}
  def find(s,step = 0,memory_name="default",max_score=260):
      """Search a memory for chunks related to the query text ``s``.

      The query is embedded with the store's ``embedding_function`` and the
      store's index is asked for ``cunnrent_setting.count`` results. Each
      usable hit is expanded into a result dict by ``get_doc``.

      Args:
          s: query text.
          step: number of neighbouring chunks to merge on each side of a hit.
          memory_name: name of the memory to search.
          max_score: hits with a score above this value are dropped. The
              default keeps the cutoff that used to be hard-coded
              (presumably the score is a distance, lower meaning closer;
              confirm against the index in use).

      Returns:
          A list of result dicts; an empty list when the memory does not
          exist or when any error occurs (the error is printed).
      """
      try:
          store = get_vectorstore(memory_name)
          if store is None:
              # Unknown memory: report "no results" directly instead of
              # relying on an AttributeError reaching the handler below.
              return []
          embedding = store.embedding_function(s)
          scores, indices = store.index.search(np.array([embedding], dtype=np.float32), int(cunnrent_setting.count))
          docs = []
          for score, idx in zip(scores[0], indices[0]):
              # -1 is skipped as in the original code (presumably the index's
              # marker for "no result in this slot").
              if idx == -1:
                  continue
              if score > max_score:
                  continue
              docs.append(get_doc(idx,score,step,memory_name))
          return docs
      except Exception as e:
          print(e)
          return []
  # Build the module-wide `embeddings` object from the configured model path.
  # A path starting with "http" is treated as the base URL of an
  # OpenAI-compatible embedding service (e.g. "http://127.0.0.1:3000/");
  # anything else is loaded as a sentence-transformers model.
  model_path = cunnrent_setting.model_path
  try:
      if model_path.startswith("http"):
          # The OpenAI client settings are supplied through environment
          # variables, so they are set before the embeddings object is built.
          os.environ["OPENAI_API_TYPE"] = "open_ai"
          os.environ["OPENAI_API_BASE"] = model_path
          os.environ["OPENAI_API_KEY"] = "your OpenAI key"
          from langchain.embeddings.openai import OpenAIEmbeddings
          embeddings = OpenAIEmbeddings(
              deployment="text-embedding-ada-002",
              model="text-embedding-ada-002"
          )
      else:
          from langchain.embeddings import HuggingFaceEmbeddings
          # The wrapper is created with an empty model name and its client is
          # replaced right away with the model loaded from model_path.
          embeddings = HuggingFaceEmbeddings(model_name='')
          # Honour an optional `device` entry in the RTST settings; fall back
          # to "cuda", the value that used to be hard-coded, when it is
          # missing or empty.
          device = getattr(cunnrent_setting, "device", None) or "cuda"
          embeddings.client = sentence_transformers.SentenceTransformer(
              model_path, device=device)
  except Exception:
      error_helper("embedding加载失败",
                   r"https://github.com/l15y/wenda")
      # Re-raise the original exception with its traceback untouched.
      raise

  # Loaded memories, keyed by memory name.
  vectorstores = {}
  def get_vectorstore(memory_name):
      """Return the vector store for ``memory_name``, loading it if needed.

      The in-memory cache ``vectorstores`` is consulted first. On a miss the
      store is loaded from ``memory/<memory_name>`` and cached. If loading
      fails for any reason a notice is printed and ``None`` is returned, which
      callers treat as "this memory does not exist yet".
      """
      try:
          return vectorstores[memory_name]
      except Exception:
          # Not cached (or not usable as a key): fall through to the disk load.
          pass

      try:
          loaded = Vectorstore.load_local(
              'memory/'+memory_name, embeddings=embeddings)
          vectorstores[memory_name] = loaded
          return vectorstores[memory_name]
      except Exception:
          success_print("没有读取到RTST记忆区%s,将新建。"%memory_name)
      return None
  104. from langchain.docstore.document import Document
  105. from langchain.text_splitter import CharacterTextSplitter
  106. from bottle import route, response, request, static_file, hook
  107. import bottle
  def _normalize_newlines(text):
      """Convert carriage returns to newlines and drop blank lines."""
      # Carriage returns are converted first so that the single collapse pass
      # below also sees blank lines that were written with bare "\r"; doing
      # it the other way round left blank lines behind for such input.
      text = re.sub(r'\r', "\n", text)
      return re.sub(r"\n\s*\n", "\n", text)


  @route('/upload_rtst_zhishiku', method=("POST","OPTIONS"))
  def upload_zhishiku():
      """Add a text to a memory (in RAM only; saving is a separate route).

      Reads ``title``, ``memory_name`` and ``txt`` from the JSON body, cleans
      up the line breaks of ``txt``, splits it into chunks on newlines, embeds
      the chunks and either creates the named memory or merges the new chunks
      into the existing one.

      Returns '成功' on success, otherwise the text of the exception.
      """
      allowCROS()
      try:
          data = request.json
          title = data.get("title")
          memory_name = data.get("memory_name")
          data = _normalize_newlines(data.get("txt"))
          docs = [Document(page_content=data, metadata={"source":title })]
          print(docs)

          # Use the configured chunk size/overlap when both are present,
          # otherwise fall back to 20 characters without overlap.
          rtst = settings.librarys.rtst
          if hasattr(rtst,"size") and hasattr(rtst,"overlap"):
              chunk_size, chunk_overlap = int(rtst.size), int(rtst.overlap)
          else:
              chunk_size, chunk_overlap = 20, 0
          text_splitter = CharacterTextSplitter(
              chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator='\n')

          doc_texts = text_splitter.split_documents(docs)
          texts = [d.page_content for d in doc_texts]
          metadatas = [d.metadata for d in doc_texts]
          vectorstore_new = Vectorstore.from_texts(texts, embeddings, metadatas=metadatas)

          vectorstore = get_vectorstore(memory_name)
          if vectorstore is None:
              vectorstores[memory_name] = vectorstore_new
          else:
              vectorstores[memory_name].merge_from(vectorstore_new)
          return '成功'
      except Exception as e:
          return str(e)
  @route('/save_rtst_zhishiku', method=("POST","OPTIONS"))
  def save_zhishiku():
      """Write a loaded memory to ``memory/<memory_name>``.

      ``memory_name`` comes from the JSON body. Because it becomes part of a
      filesystem path it must be a plain name: non-strings, empty names,
      ``.``/``..`` and anything containing a path separator are rejected so a
      request cannot make the store be written outside the memory directory.

      Returns "保存成功" on success, otherwise the text of the error.
      """
      allowCROS()
      try:
          data = request.json
          memory_name = data.get("memory_name")
          if (not isinstance(memory_name, str)
                  or memory_name in ("", ".", "..")
                  or os.path.basename(memory_name) != memory_name):
              raise ValueError("invalid memory_name")
          if memory_name not in vectorstores:
              # Clearer than the bare KeyError text, which is just the
              # quoted name.
              raise ValueError("memory %r is not loaded" % memory_name)
          vectorstores[memory_name].save_local('memory/'+memory_name)
          return "保存成功"
      except Exception as e:
          return str(e)
  152. import json
  @route('/find_rtst_in_memory', method=("POST","OPTIONS"))
  def api_find():
      """Search a memory and return the hits as a JSON array.

      Reads ``prompt``, ``step`` and ``memory_name`` from the JSON body. When
      ``step`` is absent the value of ``settings.library.step`` is used. On
      any error the text of the exception is returned instead.

      Note: the function name ``api_find`` is reused by other route handlers
      in this file; the route decorator is what registers each handler.
      """
      allowCROS()
      try:
          payload = request.json
          query = payload.get('prompt')
          requested_step = payload.get('step')
          memory_name = payload.get("memory_name")
          step = int(settings.library.step) if requested_step is None else requested_step
          hits = find(query, int(step), memory_name)
          return json.dumps(hits)
      except Exception as e:
          return str(e)
  @route('/list_rtst_in_disk', method=("POST","OPTIONS"))
  def api_find():
      """Return the entries of the ``memory`` directory as a JSON array."""
      allowCROS()
      entries = os.listdir('memory')
      return json.dumps(entries)
  @route('/del_rtst_in_memory', method=("POST","OPTIONS"))
  def api_find():
      """Drop a memory from the in-memory cache ``vectorstores``.

      ``memory_name`` is read from the JSON body. Nothing is returned on
      success; on any error (for example an unknown name) the text of the
      exception is returned. Nothing on disk is touched here.
      """
      allowCROS()
      try:
          payload = request.json
          target = payload.get("memory_name")
          del vectorstores[target]
      except Exception as e:
          return str(e)
  @route('/save_news', method=("POST","OPTIONS"))
  def save_news():
      """Store a posted text as ``txt/<title>.txt``.

      Reads ``title`` and ``txt`` from the JSON body. ``title`` becomes part
      of a filesystem path, so it must be a string without path separators;
      this keeps a request from writing outside the ``txt`` directory.

      Returns 'no data' for an empty body, 'success' after writing, otherwise
      the text of the error.
      """
      allowCROS()
      try:
          data = request.json
          if not data:
              return 'no data'
          title = data.get('title')
          txt = data.get('txt')
          if not isinstance(title, str) or not isinstance(txt, str):
              raise ValueError("title and txt must be strings")
          if os.path.basename(title) != title:
              raise ValueError("invalid title")
          cut_file = f"txt/{title}.txt"
          # The with-statement closes the file; no explicit close() is needed.
          with open(cut_file, 'w', encoding='utf-8') as f:
              f.write(txt)
          return 'success'
      except Exception as e:
          # Return text like the other handlers do, not the exception object.
          return str(e)
粤ICP备19079148号