# wenda.py

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import time
from starlette.requests import Request
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
import asyncio
import functools
import bottle
from bottle import route, response, request, static_file, hook
import datetime
import json
import os
import threading
import torch
from plugins.common import error_helper, error_print, success_print
from plugins.common import allowCROS
from plugins.common import settings
from plugins.common import app
import logging
logging.captureWarnings(True)
logger = None
try:
    from loguru import logger
except:
    pass


def load_LLM():
    try:
        from importlib import import_module
        LLM = import_module('llms.llm_'+settings.llm_type)
        return LLM
    except Exception as e:
        logger and logger.exception(e)
        print("LLM模型加载失败,请阅读说明:https://github.com/l15y/wenda", e)  # "failed to load the LLM model, see the README"


LLM = load_LLM()

logging = settings.logging  # note: from here on, "logging" shadows the stdlib logging module imported above
if logging:
    from plugins.defineSQL import session_maker, 记录

model = None
tokenizer = None


def load_model():
    LLM.load_model()
    torch.cuda.empty_cache()
    success_print("模型加载完成")  # "model loaded"


if __name__ == '__main__':
    thread_load_model = threading.Thread(target=load_model)
    thread_load_model.start()
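# The model loads in a background thread so uvicorn can start serving the UI
# and API routes immediately; requests that need the model will not work until
# load_model() has finished.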

zhishiku = None


def load_zsk():
    try:
        global zhishiku
        import plugins.zhishiku as zsk
        zhishiku = zsk
        success_print("知识库加载完成")  # "knowledge base loaded"
    except Exception as e:
        logger and logger.exception(e)
        error_helper(
            "知识库加载失败,请阅读说明", r"https://github.com/l15y/wenda#%E7%9F%A5%E8%AF%86%E5%BA%93")  # "knowledge base failed to load, see the README"
        raise e


if __name__ == '__main__':
    thread_load_zsk = threading.Thread(target=load_zsk)
    thread_load_zsk.start()


@route('/llm')
def llm_js():
    noCache()
    return static_file('llm_'+settings.llm_type+".js", root="llms")


@route('/plugins')
def read_auto_plugins():
    noCache()
    plugins = []
    for root, dirs, files in os.walk("autos"):
        for file in files:
            if file.endswith(".js"):
                file_path = os.path.join(root, file)
                with open(file_path, "r", encoding='utf-8') as f:
                    plugins.append({"name": file, "content": f.read()})
    return json.dumps(plugins)

# @route('/writexml', method=("POST", "OPTIONS"))
# def writexml():
#     data = request.json
#     s = json2xml(data).decode("utf-8")
#     with open(os.environ['wenda_'+'Config']+"_", 'w', encoding="utf-8") as f:
#         f.write(s)
#     # print(j)
#     return s


def noCache():
    response.set_header("Pragma", "no-cache")
    response.add_header("Cache-Control", "must-revalidate")
    response.add_header("Cache-Control", "no-cache")
    response.add_header("Cache-Control", "no-store")


def pathinfo_adjust_wrapper(func):
    # A wrapper for the _handle() method
    @functools.wraps(func)
    def _(s, environ):
        environ["PATH_INFO"] = environ["PATH_INFO"].encode(
            "utf8").decode("latin1")
        return func(s, environ)
    return _


bottle.Bottle._handle = pathinfo_adjust_wrapper(
    bottle.Bottle._handle)  # fix bottle's handling of UTF-8 characters in URLs


@hook('before_request')
def validate():
    # Treat CORS preflight requests as the method they ask about, so the
    # normal route handlers (which all accept OPTIONS) answer them directly.
    REQUEST_METHOD = request.environ.get('REQUEST_METHOD')
    HTTP_ACCESS_CONTROL_REQUEST_METHOD = request.environ.get(
        'HTTP_ACCESS_CONTROL_REQUEST_METHOD')
    if REQUEST_METHOD == 'OPTIONS' and HTTP_ACCESS_CONTROL_REQUEST_METHOD:
        request.environ['REQUEST_METHOD'] = HTTP_ACCESS_CONTROL_REQUEST_METHOD


waiting_threads = 0


@route('/chat_now', method=('GET', "OPTIONS"))
def api_chat_now():
    allowCROS()
    noCache()
    return {'queue_length': waiting_threads}


@route('/find', method=("POST", "OPTIONS"))
def api_find():
    allowCROS()
    data = request.json
    if not data:
        return '0'
    prompt = data.get('prompt')
    step = data.get('step')
    if step is None:
        step = int(settings.library.step)
    return json.dumps(zhishiku.find(prompt, int(step)))


@route('/completions', method=("POST", "OPTIONS"))
def api_chat_box():
    response.content_type = "text/event-stream"
    response.add_header("Connection", "keep-alive")
    response.add_header("Cache-Control", "no-cache")
    response.add_header("X-Accel-Buffering", "no")
    data = request.json
    messages = data.get('messages')
    stream = data.get('stream')
    prompt = messages[-1]['content']
    data['prompt'] = prompt
    history = []
    for i, old_chat in enumerate(messages[0:len(messages)-1]):
        if old_chat['role'] == "user":
            history.append(old_chat)
        elif old_chat['role'] == "assistant":
            old_chat['role'] = "AI"
            history.append(old_chat)
        else:
            continue
    data['history'] = history
    data['level'] = 0
    from websocket import create_connection
    ws = create_connection("ws://127.0.0.1:"+str(settings.port)+"/ws")
    ws.send(json.dumps(data))
    if not stream:
        response.content_type = "application/json"
        temp_result = ''
        try:
            while True:
                result = ws.recv()
                if len(result) > 0:
                    temp_result = result
        except:
            pass
        yield json.dumps({"response": temp_result})
    else:
        try:
            while True:
                result = ws.recv()
                if len(result) > 0:
                    yield "data: %s\n\n" % json.dumps({"response": result})
        except:
            pass
        yield "data: %s\n\n" % "[DONE]"
    ws.close()
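# Usage sketch (not part of the original file): api_chat_box() accepts an
# OpenAI-style message list and proxies it to the local /ws websocket, replying
# either as a single JSON body or as an SSE stream terminated by [DONE]. The
# bottle routes are reached through the /api/ (or /chat/) mount added further
# down; port 17860 below is an assumption, use the port from your config. A
# minimal streaming client could look like:
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:17860/api/completions",
#       json={"messages": [{"role": "user", "content": "你好"}], "stream": True},
#       stream=True)
#   for line in r.iter_lines():
#       if line.startswith(b"data: "):
#           print(line[len(b"data: "):].decode("utf-8"))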


@route('/chat_stream', method=("POST", "OPTIONS"))
def api_chat_stream():
    allowCROS()
    response.add_header("Connection", "keep-alive")
    response.add_header("Cache-Control", "no-cache")
    response.add_header("X-Accel-Buffering", "no")
    data = request.json
    data = json.dumps(data)
    from websocket import create_connection
    ws = create_connection("ws://127.0.0.1:"+str(settings.port)+"/ws")
    ws.send(data)
    try:
        while True:
            result = ws.recv()
            if len(result) > 0:
                yield result
    except:
        pass
    ws.close()
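# /chat_stream forwards the posted JSON verbatim to the local /ws websocket and
# streams each generated chunk back as plain text (no SSE framing).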


@route('/chat', method=("POST", "OPTIONS"))
def api_chat():
    allowCROS()
    data = request.json
    data = json.dumps(data)
    from websocket import create_connection
    ws = create_connection("ws://127.0.0.1:"+str(settings.port)+"/ws")
    ws.send(data)
    result = ''  # default, so the route does not fail if nothing is received
    try:
        while True:
            new_result = ws.recv()
            if len(new_result) > 0:
                result = new_result
    except:
        pass
    ws.close()
    print([result])
    return result
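# /chat is the blocking variant: it keeps reading from the websocket until the
# connection closes and returns only the final message, which makes it the
# simplest endpoint to call from scripts that do not need streaming.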


bottle.debug(True)


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    path = request.scope['path']
    if path.startswith('/static/') and not path.endswith(".html"):
        return response
    process_time = time.time() - start_time
    response.headers["X-Process-Times"] = str(process_time)
    response.headers["Pragma"] = "no-cache"
    response.headers["Cache-Control"] = "no-cache,no-store,must-revalidate"
    return response


users_count = [0]*4


def get_user_count_before(level):
    count = 0
    for i in range(level):
        count += users_count[i]
    return count


class AsyncContextManager:
    def __init__(self, level):
        self.level = level

    async def __aenter__(self):
        users_count[self.level] += 1

    async def __aexit__(self, exc_type, exc, tb):
        users_count[self.level] -= 1


Lock = AsyncContextManager
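# Concurrency accounting (descriptive note, not in the original file):
# users_count[i] counts requests currently generating at priority level i
# (level 0 is used by /completions, level 3 is the default for /ws callers).
# get_user_count_before(n) sums the counters below level n; the websocket
# handler below calls it with 4 to count everyone generating and, combined
# with the request's own level, decides whether the request has to queue.
# "Lock" is only this counter wrapper, not a mutual-exclusion lock, so several
# generations can run at the same time.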


@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
    global waiting_threads
    await websocket.accept()
    waiting_threads += 1
    # await asyncio.sleep(5)
    try:
        data = await websocket.receive_json()
        prompt = data.get('prompt')
        max_length = data.get('max_length')
        if max_length is None:
            max_length = 2048
        top_p = data.get('top_p')
        if top_p is None:
            top_p = 0.7
        temperature = data.get('temperature')
        if temperature is None:
            temperature = 0.9
        keyword = data.get('keyword')
        if keyword is None:
            keyword = prompt
        level = data.get('level')
        if level is None:
            level = 3
        history = data.get('history')
        history_formatted = LLM.chat_init(history)
        response = ''
        IP = websocket.client.host
        count_before = get_user_count_before(4)
        if count_before >= 4-level:
            time2sleep = (count_before+1)*level
            while time2sleep > 0:
                # "queued; users currently generating: ...; seconds remaining: ..."
                await websocket.send_text('正在排队,当前计算中用户数:'+str(count_before)+'\n剩余时间:'+str(time2sleep)+"秒")
                await asyncio.sleep(1)
                count_before = get_user_count_before(4)
                if count_before < 4-level:
                    break
                time2sleep -= 1
        lock = Lock(level)
        async with lock:
            print("\033[1;32m"+IP+":\033[1;31m"+prompt+"\033[1;37m")
            try:
                for response in LLM.chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
                    if (response):
                        # start = time.time()
                        await websocket.send_text(response)
                        await asyncio.sleep(0)
                        # end = time.time()
                        # cost+=end-start
            except Exception as e:
                error = str(e)
                await websocket.send_text("错误" + error)  # "error" + message
                await websocket.close()
                raise e
            torch.cuda.empty_cache()
        if logging:
            with session_maker() as session:
                # 记录 is the SQL log model: 时间 = time, 问 = question, 答 = answer
                jl = 记录(时间=datetime.datetime.now(),
                        IP=IP, 问=prompt, 答=response)
                session.add(jl)
                session.commit()
        print(response)
        await websocket.close()
    except WebSocketDisconnect:
        pass
    waiting_threads -= 1
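# Client sketch (not part of the original file): the HTTP routes above talk to
# this endpoint through websocket-client's create_connection, and an external
# caller can do the same. Port 17860 is an assumption; use the configured
# settings.port.
#
#   import json
#   from websocket import create_connection
#   ws = create_connection("ws://127.0.0.1:17860/ws")
#   ws.send(json.dumps({"prompt": "你好", "history": [], "level": 3}))
#   try:
#       while True:
#           chunk = ws.recv()   # one generated response chunk per message
#           if chunk:
#               print(chunk)
#   except Exception:
#       pass                    # server closes the socket when generation ends
#   ws.close()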


@app.get("/")
async def index(request: Request):
    return RedirectResponse(url="/index.html")


app.mount(path="/chat/", app=WSGIMiddleware(bottle.app[0]))
app.mount(path="/api/", app=WSGIMiddleware(bottle.app[0]))
app.mount("/txt/", StaticFiles(directory="txt"), name="txt")
app.mount("/", StaticFiles(directory="views"), name="static")
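# The same bottle application is mounted twice, so every bottle route above is
# reachable under both /chat/... and /api/...; everything else is served as
# static files from views/ (the web UI) and txt/.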

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=settings.port,
                log_level='error', loop="asyncio")