llm_qwen.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from transformers import StoppingCriteriaList
  2. import torch
  3. from transformers import TextIteratorStreamer
  4. from plugins.common import settings
  5. from threading import Thread
  6. from transformers import AutoModelForCausalLM, AutoTokenizer
  7. from transformers.generation import GenerationConfig
  8. stopping_text = None
  9. if stopping_text:
  10. stopping_criteria = StoppingCriteriaList(
  11. [lambda input_ids, scores: tokenizer.decode(input_ids[0]).endswith(stopping_text)])
  12. else:
  13. stopping_criteria = []
  14. class ThreadWithReturnValue(Thread):
  15. def run(self):
  16. if self._target is not None:
  17. self._return = self._target(*self._args, **self._kwargs)
  18. def join(self):
  19. super().join()
  20. return self._return
  21. def chat_init(history):
  22. tmp = []
  23. # print(history)
  24. tmp_conver = []
  25. for i, old_chat in enumerate(history):
  26. if old_chat['role'] == "user":
  27. tmp_conver.append(old_chat['content'])
  28. elif old_chat['role'] == "AI":
  29. tmp_conver.append(old_chat['content'])
  30. tmp.append(tmp_conver)
  31. tmp_conver = []
  32. else:
  33. continue
  34. return tmp
  35. def chat_one(prompt, history, max_length, top_p, temperature, data):
  36. model.generation_config.top_p=top_p
  37. model.generation_config.temperature=temperature
  38. model.generation_config.max_new_tokens=max_length
  39. for response in model.chat_stream(tokenizer, prompt, history=history):
  40. yield response
  41. def load_model():
  42. global model, tokenizer
  43. tokenizer = AutoTokenizer.from_pretrained(
  44. settings.llm.path, trust_remote_code=True)
  45. # use bf16
  46. # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
  47. # use fp16
  48. # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
  49. # use cpu only
  50. # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
  51. # use auto mode, automatically select precision based on the device.
  52. model = AutoModelForCausalLM.from_pretrained(
  53. settings.llm.path, device_map="auto", trust_remote_code=True).eval()
  54. model.generation_config = GenerationConfig.from_pretrained(settings.llm.path, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参
粤ICP备19079148号