llm_replitcode.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
import sys

from plugins.common import settings
  2. def chat_init(history):
  3. return []
  4. def chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
  5. yield str(len(prompt))+'字正在计算'
  6. x = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
  7. y = model.generate(x, max_length=200, do_sample=True, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
  8. # decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
  9. response = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
  10. yield response
  11. def load_model():
  12. global model, tokenizer
  13. from transformers import AutoModelForCausalLM, AutoTokenizer
  14. import torch
  15. tokenizer = AutoTokenizer.from_pretrained(
  16. settings.llm.path, local_files_only=True, trust_remote_code=True)
  17. model = AutoModelForCausalLM.from_pretrained(
  18. settings.llm.path, local_files_only=True, trust_remote_code=True)
  19. device, precision = settings.llm.strategy.split()
  20. # 根据设备执行不同的操作
  21. if device == 'cpu':
  22. # 如果是cpu,不做任何操作
  23. pass
  24. elif device == 'cuda':
  25. # 如果是gpu,把模型移动到显卡
  26. import torch
  27. # 根据精度执行不同的操作
  28. if precision == 'fp16':
  29. # 如果是fp16,把模型转化为半精度
  30. model = model.half()
  31. elif precision == 'fp32':
  32. # 如果是fp32,把模型转化为全精度
  33. model = model.float()
  34. elif precision.startswith('fp16i'):
  35. # 如果是fp16i开头,把模型转化为指定的精度
  36. # 从字符串中提取精度的数字部分
  37. bits = int(precision[5:])
  38. # 调用quantize方法,传入精度参数
  39. model = model.quantize(bits)
  40. model = model.half()
  41. elif precision.startswith('fp32i'):
  42. # 如果是fp32i开头,把模型转化为指定的精度
  43. # 从字符串中提取精度的数字部分
  44. bits = int(precision[5:])
  45. # 调用quantize方法,传入精度参数
  46. model = model.quantize(bits)
  47. model = model.float()
  48. else:
  49. # 如果是其他精度,报错并退出程序
  50. print('Error: 不受支持的精度')
  51. exit()
  52. model = model.to(torch.device("cuda"))
  53. else:
  54. # 如果是其他设备,报错并退出程序
  55. print('Error: 不受支持的设备')
  56. exit()
  57. model = model.eval()
粤ICP备19079148号