llm_llama.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from plugins.common import settings
  2. if settings.llm.strategy.startswith("Q"):
  3. runtime = "cpp"
  def chat_init(history):
      """Flatten earlier chat turns into the plain-text prefix used by ``chat_one``.

      Args:
          history: Iterable of message dicts with ``'role'`` and ``'content'``
              keys, or ``None`` when there are no earlier turns.

      Returns:
          A string with one line per kept turn -- ``"Q: <content>"`` for role
          ``user`` and ``" A: <content>"`` for role ``AI`` or ``assistant`` --
          each terminated by a newline, followed by one trailing space.
          Turns with any other role are dropped. ``None`` and an empty
          history both yield ``" "``.
      """
      if history is None:
          # The previous code fell through to ``None + " "`` and raised
          # TypeError; a missing history is treated like an empty one.
          return " "

      # NOTE(review): history is framed as "Q:" / " A:" here while chat_one in
      # this branch frames the live turn as "Human:" / "Assistant:" -- confirm
      # the mismatch is intended.
      turns = []
      for old_chat in history:
          role = old_chat['role']
          if role == "user":
              turns.append("Q: " + old_chat['content'] + '\n')
          elif role in ("AI", "assistant"):
              turns.append(" A: " + old_chat['content'] + '\n')
      return "".join(turns) + " "
  def chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
      """Stream the reply to ``prompt`` from the loaded llama.cpp model.

      The formatted history is prepended to a ``Human: ... / Assistant: ``
      frame and sent to the module-level ``model`` in streaming mode. After
      every received chunk the generator yields the whole reply accumulated so
      far (not just the new piece). ``data`` is not read.
      """
      full_prompt = history_formatted + "Human: %s\nAssistant: " % prompt
      request = dict(
          stop=["Human:", "### Hum"],
          temperature=temperature,
          max_tokens=max_length,
          top_p=top_p,
          stream=True,
      )
      reply_so_far = ""
      for chunk in model(full_prompt, **request):
          piece = chunk["choices"][0]["text"]
          reply_so_far = reply_so_far + piece
          yield reply_so_far
  def load_model():
      """Create the llama.cpp model and publish it as the module global ``model``.

      The model file comes from ``settings.llm.path``. If the strategy string
      has an integer after ``->`` (e.g. ``"Q->8"``) it is passed as
      ``n_threads``; otherwise ``n_threads`` is omitted and llama.cpp picks its
      own default.
      """
      global model
      from llama_cpp import Llama

      llama_kwargs = dict(model_path=settings.llm.path, use_mlock=True, n_ctx=4096)
      try:
          llama_kwargs["n_threads"] = int(settings.llm.strategy.split('->')[1])
      except (IndexError, ValueError):
          # No "->" part, or it is not an integer: load without n_threads.
          # Only the parse is guarded -- the old bare ``except`` also swallowed
          # genuine load failures (and Ctrl-C) and then loaded the model again.
          pass
      model = Llama(**llama_kwargs)
  33. else:
  34. runtime = "torch"
  35. user = "Human"
  36. answer = "Assistant"
  37. interface = ":"
  38. import torch
  39. import gc
  40. from transformers.generation.logits_process import (
  41. LogitsProcessorList,
  42. RepetitionPenaltyLogitsProcessor,
  43. TemperatureLogitsWarper,
  44. TopKLogitsWarper,
  45. TopPLogitsWarper,
  46. )
  def chat_init(history):
      """Render earlier chat turns as the transcript prefix for ``chat_one``.

      Args:
          history: Iterable of message dicts with ``'role'`` and ``'content'``
              keys, or ``None`` when there are no earlier turns.

      Returns:
          The kept turns joined by blank lines, each written as
          ``"<speaker><interface> <content>"`` using the module-level ``user``,
          ``answer`` and ``interface`` labels. Role ``user`` maps to ``user``;
          roles ``AI`` and ``assistant`` map to ``answer`` (``assistant`` is
          accepted to match the llama.cpp branch's ``chat_init``). Other roles
          are dropped. ``None`` and an empty history both yield ``""``.
      """
      if history is None:
          # ``enumerate(None)`` used to raise TypeError here.
          return ""

      speaker_for_role = {"user": user, "AI": answer, "assistant": answer}
      turns = []
      for old_chat in history:
          speaker = speaker_for_role.get(old_chat['role'])
          if speaker is None:
              continue
          turns.append(f"{speaker}{interface} " + old_chat['content'])
      return '\n\n'.join(turns)
  def partial_stop(output, stop_str):
      """Tell whether ``output`` ends with text that begins ``stop_str``.

      Used while streaming to hold back text that might turn into the stop
      sequence once more tokens arrive.

      Args:
          output: Text decoded so far.
          stop_str: The stop sequence being watched for.

      Returns:
          ``True`` if some non-empty suffix of ``output`` (up to
          ``len(stop_str)`` characters) is a prefix of ``stop_str``; ``False``
          otherwise, including when either string is empty.
      """
      longest = min(len(output), len(stop_str))
      # Suffix lengths run 1..longest. The old loop started at 0, where
      # ``output[-0:]`` is the *whole* string, so the longest suffix was only
      # tested by accident when ``output`` was no longer than ``stop_str``.
      return any(
          stop_str.startswith(output[-size:]) for size in range(1, longest + 1)
      )
  def prepare_logits_processor(
      temperature: float, repetition_penalty: float, top_p: float, top_k: int
  ) -> LogitsProcessorList:
      """Assemble the logits-processing chain for one generation call.

      A stage is included only when its setting passes the check below; kept
      stages appear in the order temperature, repetition penalty, top-p, top-k.
      The returned list is empty when no stage applies.
      """
      stages = (
          # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a
          # no-op, so both cases are skipped.
          (temperature >= 1e-5 and temperature != 1.0, TemperatureLogitsWarper, temperature),
          (repetition_penalty > 1.0, RepetitionPenaltyLogitsProcessor, repetition_penalty),
          (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
          (top_k > 0, TopKLogitsWarper, top_k),
      )
      chain = LogitsProcessorList()
      for wanted, build, setting in stages:
          if wanted:
              chain.append(build(setting))
      return chain
  @torch.inference_mode()
  def generate_stream(
      model, tokenizer, query: str, max_length=2048, do_sample=True, top_p=1.0, temperature=1.0, logits_processor=None
  ):
      """Stream a completion for ``query`` as a series of decoded snapshots.

      Runs a manual token-by-token loop that reuses ``past_key_values`` between
      steps, for at most 256 new tokens. Generation stops early on the
      tokenizer's EOS token or when the decoded text contains three consecutive
      newlines. Text that might be the start of that stop sequence is held back
      until it is resolved.

      Args:
          model: Model called as ``model(input_ids, use_cache=True, ...)``; its
              ``config.is_encoder_decoder`` flag selects the decoding path.
          tokenizer: Called as ``tokenizer(text).input_ids``; also provides
              ``eos_token_id`` and ``decode``.
          query: Prompt text.
          max_length: Token budget; the prompt fed to the model is cut to its
              last ``max_length - 256 - 8`` tokens.
          do_sample: Accepted for interface compatibility; not read.
          top_p: Nucleus-sampling threshold; values below ``1e-8`` select
              greedy decoding.
          temperature: Sampling temperature; values below ``1e-5`` select
              greedy decoding.
          logits_processor: Accepted for interface compatibility; the chain is
              always rebuilt from ``temperature`` and ``top_p``.

      Yields:
          Dicts with ``"text"`` (the reply decoded so far), ``"usage"`` (token
          counts) and ``"finish_reason"``. Intermediate items carry
          ``finish_reason=None``; the last item carries ``"length"``,
          ``"stop"`` or ``None``.
      """
      # Needed by the stop-sequence type check below; the name was previously
      # used without ever being imported.
      from collections.abc import Iterable

      prompt = query
      len_prompt = len(prompt)
      repetition_penalty = 1.0
      top_k = -1  # -1 means disable
      max_new_tokens = 256
      stop_str = '\n\n\n'
      echo = False
      stop_token_ids = [tokenizer.eos_token_id]
      # Follow the device the model actually lives on, so a CPU strategy works
      # too; fall back to the former hard-coded 'cuda' when the model object
      # does not expose a device.
      device = getattr(model, "device", None) or 'cuda'
      stream_interval = 2  # decode and emit every second generated token

      logits_processor = prepare_logits_processor(
          temperature, repetition_penalty, top_p, top_k
      )

      input_ids = tokenizer(prompt).input_ids
      # Both values below are taken before truncation, so slicing
      # ``output_ids[input_echo_len:]`` later still isolates the generated part.
      input_echo_len = len(input_ids)
      output_ids = list(input_ids)

      # NOTE(review): for max_length <= 264 this is zero or negative, and the
      # slice then drops tokens from the *front* by index instead of keeping a
      # tail of that size -- confirm how small budgets should behave.
      max_src_len = max_length - max_new_tokens - 8
      input_ids = input_ids[-max_src_len:]

      is_encoder_decoder = model.config.is_encoder_decoder
      if is_encoder_decoder:
          # These two values were referenced by the decoding loop but never
          # defined. NOTE(review): this set-up follows FastChat's reference
          # loop and has not been exercised with an encoder-decoder model.
          encoder_output = model.encoder(
              input_ids=torch.as_tensor([input_ids], device=device)
          )[0]
          start_ids = torch.as_tensor(
              [[model.generation_config.decoder_start_token_id]],
              dtype=torch.int64,
              device=device,
          )

      past_key_values = out = None
      try:
          for i in range(max_new_tokens):
              if i == 0:
                  # First step: feed the whole (truncated) prompt.
                  if is_encoder_decoder:
                      out = model.decoder(
                          input_ids=start_ids,
                          encoder_hidden_states=encoder_output,
                          use_cache=True,
                      )
                      logits = model.lm_head(out[0])
                  else:
                      out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                      logits = out.logits
              else:
                  # Later steps: feed only the newest token plus the cache.
                  step_ids = torch.as_tensor([[token]], device=device)
                  if is_encoder_decoder:
                      out = model.decoder(
                          input_ids=step_ids,
                          encoder_hidden_states=encoder_output,
                          use_cache=True,
                          past_key_values=past_key_values,
                      )
                      logits = model.lm_head(out[0])
                  else:
                      out = model(
                          input_ids=step_ids,
                          use_cache=True,
                          past_key_values=past_key_values,
                      )
                      logits = out.logits
              past_key_values = out.past_key_values

              if logits_processor:
                  if repetition_penalty > 1.0:
                      tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
                  else:
                      tmp_output_ids = None
                  last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
              else:
                  last_token_logits = logits[0, -1, :]

              if temperature < 1e-5 or top_p < 1e-8:  # greedy
                  token = int(torch.argmax(last_token_logits))
              else:
                  probs = torch.softmax(last_token_logits, dim=-1)
                  token = int(torch.multinomial(probs, num_samples=1))

              output_ids.append(token)
              stopped = token in stop_token_ids

              if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
                  if echo:
                      tmp_output_ids = output_ids
                      rfind_start = len_prompt
                  else:
                      tmp_output_ids = output_ids[input_echo_len:]
                      rfind_start = 0

                  output = tokenizer.decode(
                      tmp_output_ids,
                      skip_special_tokens=True,
                      spaces_between_special_tokens=False,
                  )

                  partially_stopped = False
                  if stop_str:
                      if isinstance(stop_str, str):
                          pos = output.rfind(stop_str, rfind_start)
                          if pos != -1:
                              output = output[:pos]
                              stopped = True
                          else:
                              partially_stopped = partial_stop(output, stop_str)
                      elif isinstance(stop_str, Iterable):
                          for each_stop in stop_str:
                              pos = output.rfind(each_stop, rfind_start)
                              if pos != -1:
                                  output = output[:pos]
                                  stopped = True
                                  break
                              else:
                                  partially_stopped = partial_stop(output, each_stop)
                                  if partially_stopped:
                                      break
                      else:
                          raise ValueError("Invalid stop field type.")

                  # prevent yielding partial stop sequence
                  if not partially_stopped:
                      yield {
                          "text": output,
                          "usage": {
                              "prompt_tokens": input_echo_len,
                              "completion_tokens": i,
                              "total_tokens": input_echo_len + i,
                          },
                          "finish_reason": None,
                      }

              if stopped:
                  break

          # finish stream event, which contains finish reason
          if i == max_new_tokens - 1:
              finish_reason = "length"
          elif stopped:
              finish_reason = "stop"
          else:
              finish_reason = None

          yield {
              "text": output,
              "usage": {
                  "prompt_tokens": input_echo_len,
                  "completion_tokens": i,
                  "total_tokens": input_echo_len + i,
              },
              "finish_reason": finish_reason,
          }
      finally:
          # clean -- in ``finally`` so the cache is also released when the
          # consumer stops iterating early or a step raises.
          del past_key_values, out
          gc.collect()
          torch.cuda.empty_cache()
  def chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
      """Stream the reply to ``prompt`` from the loaded torch model.

      A prompt starting with ``raw!`` is sent to the model verbatim, minus that
      marker. Any other prompt is wrapped in the ``user`` / ``answer`` turn
      frame, appended to ``history_formatted`` and stripped of leading and
      trailing newlines.

      Yields:
          First a status string containing the context length in characters,
          then the ``"text"`` of every item produced by ``generate_stream``.
          ``data`` is not read.
      """
      raw_marker = "raw!"
      if prompt.startswith(raw_marker):
          print("LLAMA raw mode!")
          # Remove only the leading marker; ``replace`` used to delete every
          # later occurrence of "raw!" inside the prompt as well.
          ctx = prompt[len(raw_marker):]
      else:
          # NOTE(review): branch nesting was reconstructed from a flattened
          # listing -- confirm that raw mode is meant to skip the history
          # prefix and the newline strip.
          ctx = f"\n\n{user}{interface} {prompt}\n\n{answer}{interface}"
          ctx = history_formatted + ctx
          ctx = ctx.strip('\n')
      yield str(len(ctx)) + '字正在计算'
      for response in generate_stream(model, tokenizer, ctx,
                                      max_length=max_length, top_p=top_p, temperature=temperature):
          yield response['text']
  def sum_values(dict):
      """Return the sum of all values in the mapping (``0`` when it is empty).

      The parameter name shadows the builtin ``dict``; it is kept unchanged so
      existing keyword callers keep working.
      """
      return sum(dict.values())
  def dict_to_list(d):
      """Expand a ``{key: count}`` mapping into a flat list.

      Each key appears ``count`` times, in the mapping's iteration order; keys
      whose count is zero or negative contribute nothing.
      """
      return [item for key, count in d.items() for item in [key] * count]
  def load_model():
      """Load the tokenizer and model and publish them as module globals.

      ``settings.llm.strategy`` is a ``->``-separated list of entries of the
      form ``"<device> <precision> [*<layers>]"``. The first entry selects the
      device (``cpu``, ``cuda``, or ``cuda:<n>`` when several entries are
      given) and the precision (``fp16``, ``fp32``, ``fp16i<bits>`` or
      ``fp32i<bits>``). With more than one entry a per-layer device map is
      built and the model is dispatched through ``accelerate``.

      Side effects: sets the globals ``model`` and ``tokenizer``; prints the
      parsed strategy; on an unsupported device or precision prints an error
      and exits the process with status 1.
      """
      global model, tokenizer
      import sys

      import torch
      from transformers import AutoModelForCausalLM, AutoTokenizer

      num_trans_layers = 28
      s = [x.strip().split(' ') for x in settings.llm.strategy.split('->')]
      print(s)
      multi_device = len(s) > 1

      if len(s[0]) < 2:
          # A strategy such as "cuda" has no precision token; report it up
          # front instead of failing with IndexError after the model loaded.
          print('Error: 不受支持的精度')
          sys.exit(1)

      if multi_device:
          from accelerate import dispatch_model
          start_device = int(s[0][0].split(':')[1])
          # NOTE(review): these module names and the fixed layer count of 28
          # look like they target a different architecture's layout than the
          # one AutoModelForCausalLM builds for LLaMA checkpoints -- confirm
          # before relying on multi-device dispatch.
          device_map = {'transformer.word_embeddings': start_device,
                        'transformer.final_layernorm': start_device, 'lm_head': start_device}
          n = {}
          for si in s:
              if len(si) > 2:
                  ss = si[2]
                  if ss.startswith('*'):
                      # Explicit "*<k>": this device takes k slots.
                      n[int(si[0].split(':')[1])] = int(ss[1:])
              else:
                  # No count given: this device takes whatever is left of the
                  # layers plus the two extra slots reserved below.
                  n[int(si[0].split(':')[1])] = num_trans_layers + 2 - sum_values(n)
          # The first device gives up two slots; the embeddings, final norm
          # and lm_head are already pinned to it in device_map.
          n[start_device] -= 2
          n = dict_to_list(n)
          for i in range(num_trans_layers):
              device_map[f'transformer.layers.{i}'] = n[i]

      tokenizer = AutoTokenizer.from_pretrained(
          settings.llm.path, use_fast=False)
      model = AutoModelForCausalLM.from_pretrained(
          settings.llm.path, low_cpu_mem_usage=True, torch_dtype=torch.float16)

      lora_path = settings.llm.lora
      if lora_path not in ('', None):
          print('Lora模型地址', lora_path)
          from peft import PeftModel
          model = PeftModel.from_pretrained(model, lora_path, adapter_name=lora_path)

      device, precision = s[0][0], s[0][1]

      # Act on the requested device.
      if device == 'cpu':
          # CPU: nothing to move.
          pass
      elif device == 'cuda':
          # GPU: move the model now, unless an fp16i precision was requested
          # on a card reporting under 1.4e10 bytes of memory -- in that case
          # the move is left to the quantisation branch below.
          if not (precision.startswith('fp16i') and torch.cuda.get_device_properties(0).total_memory < 1.4e+10):
              model = model.cuda()
      elif multi_device and device.startswith('cuda:'):
          # Multi-device: placement is handled by dispatch_model at the end.
          pass
      else:
          # Any other device is unsupported: report and exit. ``sys.exit`` is
          # used because the bare ``exit`` helper comes from the site module,
          # may be missing, and exits with status 0.
          print('Error: 不受支持的设备')
          sys.exit(1)

      # Act on the requested precision.
      if precision == 'fp16':
          # Half precision.
          model = model.half()
      elif precision == 'fp32':
          # Full precision.
          model = model.float()
      elif precision.startswith(('fp16i', 'fp32i')):
          # "fp16i<bits>" / "fp32i<bits>": quantise to the bit width after the
          # five-character prefix, then cast to the base precision.
          bits = int(precision[5:])
          model = model.quantize(bits)
          if device == 'cuda':
              model = model.cuda()
          model = model.half() if precision.startswith('fp16i') else model.float()
      else:
          # Any other precision is unsupported: report and exit.
          print('Error: 不受支持的精度')
          sys.exit(1)

      if multi_device:
          model = dispatch_model(model, device_map=device_map)
      model = model.eval()
粤ICP备19079148号