llm_baichuan.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import torch
  2. from transformers import TextIteratorStreamer
  3. from plugins.common import settings
  4. from threading import Thread
  5. user = settings.llm.user
  6. answer = settings.llm.answer
  7. interface =settings.llm.interface
  8. gptq = False
  9. from transformers import StoppingCriteria, StoppingCriteriaList
  10. stopping_criteria_text="\nHuman:"
  11. if stopping_criteria_text:
  12. stopping_criteria = StoppingCriteriaList([lambda input_ids, scores: tokenizer.decode(input_ids[0]).endswith(stopping_criteria_text)])
  13. else:
  14. stopping_criteria=[]
  15. if settings.llm.path.lower().find("gptq") > -1:
  16. print("gptq mode!")
  17. gptq = True
  18. class ThreadWithReturnValue(Thread):
  19. def run(self):
  20. if self._target is not None:
  21. self._return = self._target(*self._args, **self._kwargs)
  22. def join(self):
  23. super().join()
  24. return self._return
  25. def chat_init(history):
  26. tmp = []
  27. # print(history)
  28. for i, old_chat in enumerate(history):
  29. if old_chat['role'] == "user":
  30. tmp.append(f"{user}{interface}"+old_chat['content'])
  31. elif old_chat['role'] == "AI":
  32. tmp.append(f"{answer}{interface}"+old_chat['content'])
  33. else:
  34. continue
  35. history = '\n'.join(tmp)
  36. return history
  37. def chat_one(prompt, history, max_length, top_p, temperature, data):
  38. if prompt.startswith("raw!"):
  39. print("[raw mode]", end="")
  40. prompt = prompt.replace("raw!", "")
  41. else:
  42. prompt = f"{user}{interface}{prompt}\n{answer}{interface}"
  43. if history is None:
  44. history = ""
  45. else:
  46. history += '\n'
  47. prompt = history+prompt
  48. inputs = tokenizer(prompt, return_tensors='pt')
  49. if gptq:
  50. inputs = inputs.input_ids.cuda()
  51. else:
  52. inputs = inputs.to('cuda:0')
  53. yield str(len(prompt))+'字正在计算'
  54. streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,timeout=5)
  55. if gptq:
  56. thread = ThreadWithReturnValue(target=model.generate, kwargs=dict(
  57. inputs=inputs, max_new_tokens=max_length, temperature=temperature, top_p=top_p, repetition_penalty=1.1,stopping_criteria=stopping_criteria, streamer=streamer))
  58. else:
  59. thread = ThreadWithReturnValue(target=model.generate, kwargs=dict(
  60. inputs, max_new_tokens=max_length, temperature=temperature, top_p=top_p, repetition_penalty=1.1, streamer=streamer))
  61. thread.start()
  62. generated_text = ""
  63. for new_text in streamer:
  64. if new_text != '':
  65. generated_text += new_text
  66. yield generated_text.removesuffix("</s>")
  67. pred = thread.join()
  68. output=tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(prompt):]
  69. if stopping_criteria_text:
  70. if output.endswith(stopping_criteria_text):
  71. output = output[:-len(stopping_criteria_text)]
  72. yield output
  73. def load_model():
  74. global model, tokenizer
  75. from transformers import AutoModelForCausalLM, AutoTokenizer
  76. if gptq:
  77. from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
  78. tokenizer = AutoTokenizer.from_pretrained(
  79. settings.llm.path, use_fast=True)
  80. model = AutoGPTQForCausalLM.from_quantized(settings.llm.path,
  81. model_basename=settings.llm.basename,
  82. use_safetensors=True,
  83. trust_remote_code=True,
  84. device="cuda:0",
  85. use_triton=False,
  86. quantize_config=None)
  87. else:
  88. tokenizer = AutoTokenizer.from_pretrained(
  89. settings.llm.path, trust_remote_code=True, revision="1")
  90. model = AutoModelForCausalLM.from_pretrained(
  91. settings.llm.path, trust_remote_code=True,
  92. low_cpu_mem_usage=True,
  93. torch_dtype=torch.float16,
  94. revision="1")
  95. if not (settings.llm.lora == '' or settings.llm.lora == None):
  96. print('Lora模型地址', settings.llm.lora)
  97. from peft import PeftModel
  98. model = PeftModel.from_pretrained(
  99. model, settings.llm.lora, adapter_name=settings.llm.lora)
  100. if settings.llm.path.lower().find("13b"):
  101. model=model.quantize(8)
  102. model = model.cuda()
  103. model = model.eval()
  104. if not (settings.llm.lora == '' or settings.llm.lora == None):
  105. from bottle import route, response, request
  106. @route('/lora_load_adapter', method=("POST", "OPTIONS"))
  107. def load_adapter():
  108. # allowCROS()
  109. try:
  110. data = request.json
  111. lora_path = data.get("lora_path")
  112. adapter_name = data.get("adapter_name")
  113. model.load_adapter(lora_path, adapter_name=adapter_name)
  114. return "保存成功"
  115. except Exception as e:
  116. return str(e)
  117. @route('/lora_set_adapter', method=("POST", "OPTIONS"))
  118. def set_adapter():
  119. # allowCROS()
  120. try:
  121. data = request.json
  122. adapter_name = data.get("adapter_name")
  123. model.set_adapter(adapter_name)
  124. return "保存成功"
  125. except Exception as e:
  126. return str(e)
粤ICP备19079148号