# llm_generic_transformers.py (5.2 KB)

# NOTE: the listing's line-number gutter (1-145), which had been flattened onto a single line here, was removed.
  1. from transformers import StoppingCriteriaList
  2. import torch
  3. from transformers import TextIteratorStreamer
  4. from plugins.common import settings
  5. from threading import Thread
  6. from transformers import AutoModelForCausalLM, AutoTokenizer
  7. user = settings.llm.user
  8. answer = settings.llm.answer
  9. interface = settings.llm.interface
  10. stopping_text = settings.llm.stopping_text
  11. if stopping_text:
  12. stopping_criteria = StoppingCriteriaList(
  13. [lambda input_ids, scores: tokenizer.decode(input_ids[0]).endswith(stopping_text)])
  14. else:
  15. stopping_criteria = []
  16. if settings.llm.path.lower().find("gptq") > -1:
  17. print("gptq mode!")
  18. gptq = True
  19. else:
  20. gptq = False
  21. class ThreadWithReturnValue(Thread):
  22. def run(self):
  23. if self._target is not None:
  24. self._return = self._target(*self._args, **self._kwargs)
  25. def join(self):
  26. super().join()
  27. return self._return
  28. def chat_init(history):
  29. tmp = []
  30. # print(history)
  31. for i, old_chat in enumerate(history):
  32. if old_chat['role'] == "user":
  33. tmp.append(f"{user}{interface}"+old_chat['content'])
  34. elif old_chat['role'] == "AI":
  35. tmp.append(f"{answer}{interface}"+old_chat['content'])
  36. else:
  37. continue
  38. history = '\n'.join(tmp)
  39. return history
  40. def chat_one(prompt, history, max_length, top_p, temperature, data):
  41. if prompt.startswith("raw!"):
  42. print("[raw mode]", end="")
  43. prompt = prompt.replace("raw!", "")
  44. else:
  45. prompt = f"{user}{interface}{prompt}\n{answer}{interface}"
  46. if history is None:
  47. history = ""
  48. else:
  49. history += '\n'
  50. prompt = history+prompt
  51. inputs = tokenizer(prompt, return_tensors='pt')
  52. if gptq:
  53. inputs = inputs.input_ids.cuda()
  54. else:
  55. inputs = inputs.to('cuda:0')
  56. yield str(len(prompt))+'字正在计算'
  57. streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
  58. if gptq:
  59. thread = ThreadWithReturnValue(target=model.generate, kwargs=dict(
  60. inputs=inputs, max_new_tokens=max_length, temperature=temperature, top_p=top_p, repetition_penalty=1.1, stopping_criteria=stopping_criteria, streamer=streamer))
  61. else:
  62. thread = ThreadWithReturnValue(target=model.generate, kwargs=dict(
  63. inputs, max_new_tokens=max_length, temperature=temperature, top_p=top_p, repetition_penalty=1.1, stopping_criteria=stopping_criteria, streamer=streamer))
  64. thread.start()
  65. generated_text = ""
  66. for new_text in streamer:
  67. if new_text != '':
  68. generated_text += new_text
  69. yield generated_text.removesuffix("</s>")
  70. pred = thread.join()
  71. output = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[
  72. len(prompt):]
  73. if stopping_text:
  74. if output.endswith(stopping_text):
  75. output = output[:-len(stopping_text)]
  76. yield output
  77. def load_model():
  78. global model, tokenizer
  79. if gptq:
  80. from auto_gptq import AutoGPTQForCausalLM
  81. tokenizer = AutoTokenizer.from_pretrained(
  82. settings.llm.path, use_fast=True)
  83. model = AutoGPTQForCausalLM.from_quantized(settings.llm.path,
  84. model_basename=settings.llm.basename,
  85. use_safetensors=True,
  86. trust_remote_code=True,
  87. device="cuda:0",
  88. use_triton=False,
  89. quantize_config=None)
  90. else:
  91. tokenizer = AutoTokenizer.from_pretrained(
  92. settings.llm.path, trust_remote_code=True, revision="1")
  93. model = AutoModelForCausalLM.from_pretrained(
  94. settings.llm.path, trust_remote_code=True,
  95. low_cpu_mem_usage=True,
  96. torch_dtype=torch.float16,
  97. device_map="cuda:0",
  98. revision="1")
  99. if not (settings.llm.lora == '' or settings.llm.lora == None):
  100. print('Lora模型地址', settings.llm.lora)
  101. from peft import PeftModel
  102. model = PeftModel.from_pretrained(
  103. model, settings.llm.lora, adapter_name=settings.llm.lora)
  104. if settings.llm.path.lower().find("13b")!=-1:
  105. model = model.quantize(8)
  106. model = model.cuda()
  107. model = model.eval()
  108. if not (settings.llm.lora == '' or settings.llm.lora == None):
  109. from bottle import route, request
  110. @route('/lora_load_adapter', method=("POST", "OPTIONS"))
  111. def load_adapter():
  112. # allowCROS()
  113. try:
  114. data = request.json
  115. lora_path = data.get("lora_path")
  116. adapter_name = data.get("adapter_name")
  117. model.load_adapter(lora_path, adapter_name=adapter_name)
  118. return "保存成功"
  119. except Exception as e:
  120. return str(e)
  121. @route('/lora_set_adapter', method=("POST", "OPTIONS"))
  122. def set_adapter():
  123. # allowCROS()
  124. try:
  125. data = request.json
  126. adapter_name = data.get("adapter_name")
  127. model.set_adapter(adapter_name)
  128. return "保存成功"
  129. except Exception as e:
  130. return str(e)
# 粤ICP备19079148号 — ICP filing footer of the hosting web page; not part of this module.