YuanAPI.py

import threading
import datetime

from bottle import route, response, request, static_file
import bottle

# Optional request logging to a database; defineSQL must be available when enabled.
logging = False
if logging:
    from defineSQL import session_maker, 记录  # 记录 = "record" ORM model

# Serializes model access: one request (and the initial load) runs at a time.
mutex = threading.Lock()
@route('/static/:name')
def staticjs(name='-'):
    return static_file(name, root="views/static")


@route('/:name')
def static(name='-'):
    return static_file(name, root="views")


@route('/')
def index():
    return static_file("index.html", root="views")
# [user, question, answer] for the request currently being served.
# 当前用户 means "current user"; it starts as "model loading" (模型加载中).
当前用户 = ['模型加载中', '', '']


@route('/api/chat_now', method='GET')
def api_chat_now():
    # Plain-text status: current user / question (问题) / answer (回答).
    return '当前用户:' + 当前用户[0] + "\n问题:" + 当前用户[1] + "\n回答:" + 当前用户[2]
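# Example GET /api/chat_now reply before the model finishes loading
# (当前用户 still holds its initial value):
#
#   当前用户:模型加载中
#   问题:
#   回答: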
@route('/api/chat_stream', method='POST')
def api_chat_stream():
    data = request.json
    prompt = data.get('prompt')
    max_length = data.get('max_length')
    if max_length is None:
        max_length = 2048
    top_p = data.get('top_p')
    if top_p is None:
        top_p = 0.7
    temperature = data.get('temperature')
    if temperature is None:
        temperature = 0.9
    # NOTE: max_length, top_p and temperature are parsed but not forwarded to
    # answer() below, which uses its own defaults.
    history_formatted = None  # placeholder; chat history is not used by this API
    response = ''  # reply text; shadows bottle's response object inside this handler
    # print(request.environ)
    IP = request.environ.get(
        'HTTP_X_REAL_IP') or request.environ.get('REMOTE_ADDR')
    global 当前用户
    with mutex:
        # Progress marker: "<n> characters, computing"; '///' separates chunks.
        yield str(len(prompt)) + '字正在计算///'
        try:
            # ChatYuan prompt format: 用户 = "user", 小元 = the assistant.
            input_text = "用户:" + prompt + "\n小元:"
            response = answer(input_text)
        except Exception as e:
            print("错误", str(e), e)  # 错误 = "error"
        yield response + '///'
        if logging:
            with session_maker() as session:
                jl = 记录(时间=datetime.datetime.now(), IP=IP, 问=prompt, 答=response)
                session.add(jl)
                session.commit()
        print(f"\033[1;32m{IP}:\033[1;31m{prompt}\033[1;37m\n{response}")
        yield "/././"  # end-of-stream marker
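# Minimal client sketch for /api/chat_stream (an illustration, not part of this
# file; assumes the server is running locally and the third-party `requests`
# package is installed):
#
#   import requests
#   r = requests.post('http://127.0.0.1:17860/api/chat_stream',
#                     json={'prompt': '你好'}, stream=True)
#   for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end='')   # '///' separates chunks; '/././' marks the end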
model = None
tokenizer = None
device = None


def preprocess(text):
    text = text.replace("\n", "\\n").replace("\t", "\\t")
    return text


def postprocess(text):
    return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20', ' ')
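# Round trip: preprocess("a\nb") -> "a\\nb", so the newline reaches the model as
# a literal "\n" token; postprocess() reverses this and also maps '%20' to ' '.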
def answer(text, sample=True, top_p=1, temperature=0.7):
    '''sample: whether to sample; can be set to True for generation tasks.
    top_p: between 0 and 1; the larger it is, the more diverse the output.'''
    text = preprocess(text)
    encoding = tokenizer(text=[text], truncation=True, padding=True,
                         max_length=768, return_tensors="pt").to(device)
    if not sample:
        # Deterministic decoding.
        out = model.generate(**encoding, return_dict_in_generate=True,
                             output_scores=False, max_new_tokens=512,
                             num_beams=1, length_penalty=0.6)
    else:
        # Nucleus sampling with n-gram repetition suppression.
        out = model.generate(**encoding, return_dict_in_generate=True,
                             output_scores=False, max_new_tokens=512,
                             do_sample=True, top_p=top_p, temperature=temperature,
                             no_repeat_ngram_size=12)
    out_text = tokenizer.batch_decode(
        out["sequences"], skip_special_tokens=True)
    return postprocess(out_text[0])
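# Example call (assumes load_model() has completed), using the prompt format
# built in api_chat_stream():
#
#   answer("用户:你好\n小元:")  # returns the model's completion after "小元:"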
def load_model():
    global model, tokenizer, device
    # Hold the lock while loading so API requests block until the model is ready.
    mutex.acquire()
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    tokenizer = T5Tokenizer.from_pretrained(
        "ChatYuan-large-v2", local_files_only=True)
    model = T5ForConditionalGeneration.from_pretrained(
        "ChatYuan-large-v2", local_files_only=True).half()
    import torch
    device = torch.device('cuda')
    model.to(device)
    mutex.release()
    print("模型加载完成")  # "model loading finished"


# Load the model in a background thread so the web server can start immediately.
thread_load_model = threading.Thread(target=load_model)
thread_load_model.start()

# bottle.debug(True)
bottle.run(server='paste', port=17860, quiet=True)  # 'paste' requires the Paste package
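# To run (environment assumptions, not stated in this file): install bottle,
# paste, transformers and torch, place the ChatYuan-large-v2 checkpoint next to
# this script (local_files_only=True), then start with:
#
#   python YuanAPI.py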