llm_glm6b.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from plugins.common import settings
  2. import json
  # True when the configured model path names a ChatGLM3-6B checkpoint
  # (case-insensitive substring match); the other functions in this module
  # branch on this flag.
  chatglm3_mode = "chatglm3-6b" in settings.llm.path.lower()
  print('chatglm3_mode', chatglm3_mode)
  def chat_init(history):
      """Convert a chat transcript into the history structure given to the model.

      Args:
          history: ``None`` or an iterable of message dicts with a ``'role'``
              key and, for the roles handled here, a ``'content'`` key.

      Returns:
          A list.  In chatglm3 mode it holds one dict per user / assistant /
          system message; a system message's content is parsed as JSON and
          attached under ``"tools"``.  Otherwise it holds tuples, each built
          from the pending user content (if any) followed by the assistant
          content; system messages are ignored.  Messages with any other
          role are skipped in both modes.
      """
      formatted = []
      if history is None:
          return formatted

      pending = []  # contents collected for the tuple currently being built
      for message in history:
          role = message['role']

          if role == "user":
              if pending:
                  # A user turn is only accepted while nothing is pending.
                  continue
              if chatglm3_mode:
                  formatted.append({'role': 'user', 'content': message['content']})
              else:
                  pending.append(message['content'])
              continue

          if role in ("AI", 'assistant'):
              if chatglm3_mode:
                  formatted.append(
                      {'role': 'assistant', 'metadata': '', 'content': message['content']})
              else:
                  pending.append(message['content'])
                  formatted.append(tuple(pending))
                  pending = []
              continue

          if role == "system" and chatglm3_mode:
              formatted.append({
                  'role': 'system',
                  'content': "Answer the following questions as best as you can. You have access to the following tools:",
                  "tools": json.loads(message['content']),
              })

      return formatted
  def chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
      """Generate a reply for ``prompt``, yielding text as it becomes available.

      The first item yielded is always a status string containing the prompt
      length.  What follows depends on the history:

      * If the first history entry is a dict whose role is ``"system"``,
        ``model.chat`` is called once.  A prompt starting with
        ``"observation!"`` has that prefix removed and is sent with
        ``role="observation"``, and the response is yielded as-is; any other
        prompt yields the response serialized with ``json.dumps``.
      * Otherwise ``model.stream_chat`` is used and every intermediate
        response is yielded.

      Args:
          prompt: The user input string.
          history_formatted: History list as produced by ``chat_init`` (dicts
              or tuples).
          max_length, top_p, temperature: Forwarded unchanged to the model.
          data: Unused here; kept for interface compatibility.

      Relies on the module globals ``model`` and ``tokenizer`` set by
      ``load_model``.
      """
      yield str(len(prompt)) + '字正在计算'

      generation_args = dict(max_length=max_length, top_p=top_p, temperature=temperature)

      # The history may hold tuples instead of dicts; indexing a tuple with
      # 'role' would raise TypeError, so check the entry type first.
      first_entry = history_formatted[0] if history_formatted else None
      has_system_entry = isinstance(first_entry, dict) and first_entry.get('role') == "system"

      if not has_system_entry:
          for response, _history in model.stream_chat(
                  tokenizer, prompt, history_formatted, **generation_args):
              yield response
          return

      observation_marker = "observation!"
      if prompt.startswith(observation_marker):
          # Strip only the leading marker; later occurrences belong to the
          # observation text itself and must be kept.
          prompt = prompt[len(observation_marker):]
          response, _history = model.chat(
              tokenizer, prompt, history_formatted, role="observation", **generation_args)
          yield response
      else:
          response, _history = model.chat(
              tokenizer, prompt, history_formatted, **generation_args)
          yield json.dumps(response)
  def sum_values(dict):
      """Return the sum of all values in the mapping (0 when it is empty).

      The parameter name shadows the builtin ``dict``; it is kept so existing
      callers are unaffected.
      """
      return sum(dict.values())
  def dict_to_list(d):
      """Expand a ``{key: count}`` mapping into a flat list.

      Each key appears ``count`` times, in the mapping's iteration order.
      Counts of zero or below contribute nothing.
      """
      return [key for key, count in d.items() for _ in range(count)]
  def load_model():
      """Load the tokenizer and model configured in ``settings.llm`` into globals.

      ``settings.llm.strategy`` is a ``->``-separated list of stages, each
      written as ``<device> <precision> [*<layers>]`` (for example
      ``cuda:0 fp16 *14 -> cuda:1 fp16``).  The first stage decides the device
      and precision handling.  When there is more than one stage, a per-layer
      device map is built and applied with ``accelerate.dispatch_model``; a
      stage without a third field receives whatever remains of the layer
      budget.

      When ``settings.llm.lora`` is neither ``''`` nor ``None`` the model is
      wrapped with ``peft.PeftModel`` and two bottle routes are registered for
      loading and selecting adapters at runtime.

      Side effects:
          Sets the module globals ``model`` and ``tokenizer``, prints the
          parsed strategy, and terminates the process with exit status 1 when
          the device or precision is not supported.
      """
      global model, tokenizer
      import sys

      from transformers import AutoModel, AutoTokenizer

      num_trans_layers = 28

      # Evaluate the path checks once, case-insensitively, so the device-map
      # and device-placement decisions always agree with each other.
      lowered_path = settings.llm.path.lower()
      is_chatglm2 = "chatglm2" in lowered_path
      use_lora = settings.llm.lora is not None and settings.llm.lora != ''

      # e.g. "cuda:0 fp16 *14 -> cuda:1 fp16"
      #   -> [['cuda:0', 'fp16', '*14'], ['cuda:1', 'fp16']]
      # str.split() without arguments tolerates repeated or missing spaces.
      s = [stage.split() for stage in settings.llm.strategy.split('->')]
      print(s)
      multi_device = len(s) > 1

      if multi_device:
          from accelerate import dispatch_model
          start_device = int(s[0][0].split(':')[1])
          # Choose the module names by model path; for chatglm2 a dedicated
          # device map is used, see
          # https://github.com/THUDM/ChatGLM2-6B/blob/main/utils.py line 23.
          if is_chatglm2:
              # NOTE(review): these modules are pinned to device 0 rather than
              # start_device, as in the referenced upstream helper - confirm
              # this is intended when the first stage is not cuda:0.
              device_map = {'transformer.embedding.word_embeddings': 0,
                            'transformer.encoder.final_layernorm': 0,
                            'transformer.output_layer': 0,
                            'transformer.rotary_pos_emb': 0,
                            'lm_head': 0}
          else:
              device_map = {'transformer.word_embeddings': start_device,
                            'transformer.final_layernorm': start_device,
                            'lm_head': start_device}

          # Layer budget per device index.  The budget is num_trans_layers + 2
          # and the first device gives 2 back afterwards - presumably the room
          # reserved for the non-layer modules mapped above.
          n = {}
          for stage in s:
              if len(stage) > 2:
                  layer_spec = stage[2]
                  if layer_spec.startswith('*'):
                      n[int(stage[0].split(':')[1])] = int(layer_spec[1:])
              else:
                  n[int(stage[0].split(':')[1])] = num_trans_layers + 2 - sum_values(n)
          n[start_device] -= 2
          n = dict_to_list(n)  # position i holds the device for layer i

          layer_prefix = 'transformer.encoder.layers' if is_chatglm2 else 'transformer.layers'
          for i in range(num_trans_layers):
              device_map[f'{layer_prefix}.{i}'] = n[i]

      device, precision = s[0][0], s[0][1]
      tokenizer = AutoTokenizer.from_pretrained(
          settings.llm.path, local_files_only=True, trust_remote_code=True, revision="v1.1.0")
      model = AutoModel.from_pretrained(
          settings.llm.path, local_files_only=True, trust_remote_code=True, revision="v1.1.0")
      if use_lora:
          print('Lora模型地址', settings.llm.lora)
          from peft import PeftModel
          model = PeftModel.from_pretrained(
              model, settings.llm.lora, adapter_name=settings.llm.lora)

      # Device handling.
      if device == 'cpu':
          # Nothing to do on CPU.
          pass
      elif device == 'cuda':
          # Single GPU: move the model onto the card.
          import torch
          if is_chatglm2 and "int4" in lowered_path:
              model = model.cuda()
          elif not (precision.startswith('fp16i')
                    and torch.cuda.get_device_properties(0).total_memory < 1.4e+10):
              # Skipped only for fp16i* precision on a card reporting less
              # than 1.4e10 bytes; in that case the model is moved to the GPU
              # after quantization below.
              model = model.cuda()
      elif multi_device and device.startswith('cuda:'):
          # Multi-device: placement is done by dispatch_model further down.
          pass
      else:
          # Any other device is unsupported: report and stop the program.
          print('Error: 不受支持的设备')
          # sys.exit instead of the site-provided exit(), with a non-zero
          # status so the failure is visible to the caller.
          sys.exit(1)

      # Precision handling.
      if precision == 'fp16':
          model = model.half()
      elif precision == 'fp32':
          model = model.float()
      elif precision.startswith(('fp16i', 'fp32i')):
          # "fp16i<bits>" / "fp32i<bits>": quantize to the bit width given
          # after the five-character prefix, then cast.
          bits = int(precision[5:])
          model = model.quantize(bits)
          if device == 'cuda':
              model = model.cuda()
          model = model.half() if precision.startswith('fp16i') else model.float()
      else:
          # Any other precision is unsupported: report and stop the program.
          print('Error: 不受支持的精度')
          sys.exit(1)

      if multi_device:
          model = dispatch_model(model, device_map=device_map)
      model = model.eval()

      if use_lora:
          from bottle import route, request

          # NOTE(review): both handlers act on values taken from the request
          # body (including a filesystem path); confirm these routes are only
          # reachable by trusted clients.  The allowCROS() call from the
          # original code remains disabled.
          @route('/lora_load_adapter', method=("POST", "OPTIONS"))
          def load_adapter():
              # Load another adapter into the global model from the posted
              # "lora_path" under the posted "adapter_name".
              try:
                  data = request.json
                  lora_path = data.get("lora_path")
                  adapter_name = data.get("adapter_name")
                  model.load_adapter(lora_path, adapter_name=adapter_name)
                  return "保存成功"
              except Exception as e:
                  # Request boundary: return the error text to the client.
                  return str(e)

          @route('/lora_set_adapter', method=("POST", "OPTIONS"))
          def set_adapter():
              # Select the posted "adapter_name" on the global model.
              try:
                  data = request.json
                  adapter_name = data.get("adapter_name")
                  model.set_adapter(adapter_name)
                  return "保存成功"
              except Exception as e:
                  # Request boundary: return the error text to the client.
                  return str(e)
粤ICP备19079148号