llm_rwkv.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import torch
  2. from plugins.common import settings
  3. import time
  4. from copy import deepcopy
  5. import threading
  6. import time
  7. import math
  8. import re
  9. from typing import List,Dict
  10. interface = ":"
  11. tokenizers_type = "world"
  12. user = "User"
  13. answer = "Assistant"
  14. tokenizers_file = "rwkv_vocab_v20230424"
  15. states = {}
  16. presencePenalty = 0.2
  17. countPenalty = 0.2
  18. if settings.llm.presence_penalty:
  19. presencePenalty=settings.llm.presence_penalty
  20. if settings.llm.count_penalty:
  21. countPenalty=settings.llm.count_penalty
  22. class State(object):
  23. def __init__(self, state):
  24. self.state = [tensor.cpu() for tensor in state] if device != "cpu" else state
  25. self.touch()
  26. def get(self):
  27. self.touch()
  28. return [tensor.to(device) for tensor in self.state] if device != "cpu" else deepcopy(self.state)
  29. def touch(self):
  30. self.time = time.time()
  31. def gc_states():
  32. while True:
  33. time.sleep(3)
  34. if len(states) > 1000:
  35. oldest = [math.inf, '']
  36. for i in states:
  37. if states[i].time < oldest[0]:
  38. oldest = [states[i].time, i]
  39. del states[oldest[1]]
  40. thread_load_model = threading.Thread(target=gc_states)
  41. thread_load_model.start()
  42. device = settings.llm.state_source_device or 'cpu'
  43. if settings.llm.strategy.startswith("Q"):
  44. runtime = "cpp"
  45. from typing import Optional
  46. import tokenizers
  47. from llms.rwkvcpp.sampling import sample_logits
  48. from llms.rwkvcpp.rwkv_tokenizer import get_tokenizer
  49. logits: Optional[torch.Tensor] = None
  50. state: Optional[torch.Tensor] = None
  51. END_OF_LINE_TOKEN: int = 187
  52. def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
  53. global logits, state
  54. for _token in _tokens:
  55. logits, state = model.eval(_token, state, state, logits)
  56. logits[END_OF_LINE_TOKEN] += new_line_logit_bias
  57. def chat_init(history):
  58. global state, logits
  59. if settings.llm.historymode != 'string':
  60. if history is not None and len(history) > 0:
  61. pass
  62. else:
  63. state = None
  64. logits = None
  65. else:
  66. tmp = []
  67. # print(history)
  68. for i, old_chat in enumerate(history):
  69. if old_chat['role'] == "user":
  70. tmp.append(f"{user}{interface} "+old_chat['content'])
  71. elif old_chat['role'] == "AI":
  72. tmp.append(f"{answer}{interface} "+old_chat['content'])
  73. else:
  74. continue
  75. history = '\n\n'.join(tmp)
  76. state = None
  77. logits = None
  78. return history
  79. def chat_one(prompt, history, max_length, top_p, temperature, data):
  80. global state, resultChat, token_stop, logits
  81. token_count = max_length
  82. token_stop = [0]
  83. resultChat = ""
  84. if prompt.startswith("raw!"):
  85. print("[RWKV raw mode]", end="")
  86. ctx = prompt.replace("raw!", "")
  87. else:
  88. ctx = f"\n\n{user}{interface} {prompt}\n\n{answer}{interface}"
  89. if settings.llm.historymode == 'string':
  90. ctx = history+ctx
  91. yield str(len(ctx))+'字正在计算'
  92. new = ctx.strip()
  93. print(f'{new}', end='')
  94. process_tokens(tokenizer_encode(new),
  95. new_line_logit_bias=-999999999)
  96. accumulated_tokens: List[int] = []
  97. token_counts: Dict[int, int] = {}
  98. for i in range(int(token_count)):
  99. for n in token_counts:
  100. logits[n] -= presencePenalty + token_counts[n] * countPenalty
  101. token: int = sample_logits(logits, temperature, top_p)
  102. if token in token_stop:
  103. break
  104. if token not in token_counts:
  105. token_counts[token] = 1
  106. else:
  107. token_counts[token] += 1
  108. process_tokens([token])
  109. # Avoid UTF-8 display issues
  110. accumulated_tokens += [token]
  111. decoded: str = tokenizer.decode(accumulated_tokens)
  112. if '\uFFFD' not in decoded:
  113. resultChat = resultChat + decoded
  114. if resultChat.endswith('\n\n') or resultChat.endswith(f"{user}{interface}") or resultChat.endswith(f"{answer}{interface}"):
  115. resultChat = remove_suffix(
  116. remove_suffix(
  117. remove_suffix(
  118. remove_suffix(resultChat, f"{user}{interface}"), f"{answer}{interface}"),
  119. '\n'),
  120. '\n')
  121. yield resultChat
  122. break
  123. yield resultChat
  124. accumulated_tokens = []
  125. def remove_suffix(input_string, suffix): # 兼容python3.8
  126. if suffix and input_string.endswith(suffix):
  127. return input_string[:-len(suffix)]
  128. return input_string
  129. model = None
  130. state = None
  131. tokenizer_encode = None
  132. def load_model():
  133. global model, tokenizer, tokenizer_encode
  134. from llms.rwkvcpp.rwkv_cpp_shared_library import load_rwkv_shared_library
  135. library = load_rwkv_shared_library()
  136. print(f'System info: {library.rwkv_get_system_info_string()}')
  137. print('Loading RWKV model')
  138. from llms.rwkvcpp.rwkv_cpp_model import RWKVModel
  139. try:
  140. cpu_count = int(settings.llm.strategy.split('->')[1])
  141. model = RWKVModel(library, settings.llm.path, cpu_count)
  142. except:
  143. model = RWKVModel(library, settings.llm.path)
  144. #print('Loading 20B tokenizer')
  145. #tokenizer = tokenizers.Tokenizer.from_file(tokenizers_file)
  146. tokenizer , tokenizer_encode = get_tokenizer(tokenizers_type)
else:
    # Any strategy not starting with "Q" uses the PyTorch `rwkv` package
    # (see load_model below); the definitions that follow form that implementation.
    runtime = "torch"
  149. def chat_init(history):
  150. tmp = []
  151. # print(history)
  152. raw_mode=False
  153. for i, old_chat in enumerate(history):
  154. if old_chat['role'] == "user":
  155. if old_chat['content'].startswith("raw!"):
  156. raw_mode=True
  157. tmp.append(old_chat['content'])
  158. else:
  159. raw_mode=False
  160. tmp.append(f"{user}{interface} "+old_chat['content'])
  161. elif old_chat['role'] == "AI":
  162. if raw_mode:
  163. tmp[-1]+=old_chat['content']
  164. else:
  165. tmp.append(f"{answer}{interface} "+old_chat['content'])
  166. else:
  167. continue
  168. history = '\n\n'.join(tmp)
  169. return history
  170. def chat_one(prompt, history, max_length, top_p, temperature, data):
  171. cfg_factor = data.get('cfg_factor')
  172. cfg_ctx = data.get('cfg_ctx')
  173. cfg_ctx_history = data.get('cfg_ctx_history')
  174. if cfg_factor is None :
  175. cfg_factor=1
  176. elif cfg_factor!=1:
  177. if cfg_ctx is None:
  178. cfg_ctx=prompt
  179. if cfg_ctx_history is None:
  180. cfg_ctx_history=[]
  181. print('CFG参数',[cfg_factor,cfg_ctx,cfg_ctx_history])
  182. cfg_ctx_history=chat_init(cfg_ctx_history)
  183. if cfg_ctx_history == "":
  184. pass
  185. else:
  186. cfg_ctx_history = cfg_ctx_history+'\n\n'
  187. token_count = max_length
  188. if history is None or history == "":
  189. history = ""
  190. else:
  191. history = history+'\n\n'
  192. args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
  193. alpha_frequency=countPenalty,
  194. alpha_presence=presencePenalty,
  195. token_ban=[], # ban the generation of some tokens
  196. token_stop=[0]) # stop generation whenever you see any token here
  197. if prompt.startswith("raw!"):
  198. print("[raw mode]", end="")
  199. ctx = prompt.replace("raw!", "")
  200. ctx = re.sub('\\{user\\}',user, ctx)
  201. ctx = re.sub('\\{answer\\}',answer, ctx)
  202. ctx = re.sub('\\{bot\\}',answer, ctx)
  203. ctx = re.sub('\\{interface\\}',interface, ctx)
  204. raw_mode=True
  205. else:
  206. ctx = f"{user}{interface} {prompt}\n\n{answer}{interface}"
  207. raw_mode=False
  208. # print(ctx)
  209. state = None
  210. history_in_ctx=False
  211. try:
  212. state = states[history].get()
  213. print("[match state]", end="")
  214. except Exception as e:
  215. ctx = history+ctx
  216. history_in_ctx=True
  217. print("[default stste]", end="")
  218. if not raw_mode:
  219. state = states['default'].get()
  220. # print([history],states)
  221. all_tokens = []
  222. out_last = 0
  223. occurrence = {}
  224. tokens = pipeline.encode(ctx)
  225. if cfg_factor!=1:
  226. cfg_token = pipeline.encode(cfg_ctx_history+cfg_ctx)
  227. cfg_state=states['default'].get() #todo:使用缓存的state
  228. response = ''
  229. yield str(len(ctx))+'字正在计算\n'+str(len(tokens))+" tokens"
  230. for i in range(int(token_count)):
  231. out, state = model.forward(tokens if i == 0 else [token], state)
  232. if cfg_factor!=1:
  233. cfg_out, cfg_state = model.forward(cfg_token if i == 0 else [token], cfg_state)
  234. out = out * cfg_factor + cfg_out * (1 - cfg_factor)
  235. for n in args.token_ban:
  236. out[n] = -float('inf')
  237. for n in occurrence:
  238. out[n] -= (args.alpha_presence + occurrence[n]
  239. * args.alpha_frequency)
  240. token = pipeline.sample_logits(
  241. out, temperature=args.temperature, top_p=args.top_p)
  242. if token in args.token_stop:
  243. break
  244. all_tokens += [token]
  245. for occurrence_i in occurrence:
  246. occurrence[occurrence_i]*=0.996
  247. if token not in occurrence:
  248. occurrence[token] = 1
  249. else:
  250. occurrence[token] += 1
  251. tmp = pipeline.decode(all_tokens[out_last:])
  252. if '\ufffd' not in tmp:
  253. response += tmp
  254. if response.endswith(f"{user}{interface}"):
  255. response = remove_suffix(response,
  256. f"{user}{interface}"
  257. )
  258. break
  259. # print(tmp, end='')
  260. out_last = i + 1
  261. yield response.strip()
  262. yield response.strip()
  263. if raw_mode:
  264. if not history_in_ctx:
  265. prompt=history+prompt
  266. states[prompt+response.strip()+'\n\n'] = State(state)
  267. else:
  268. if not history_in_ctx:
  269. ctx=history+ctx
  270. states[ctx+' '+response.strip()+'\n\n'] = State(state)
  271. def remove_suffix(input_string, suffix): # 兼容python3.8
  272. if suffix and input_string.endswith(suffix):
  273. return input_string[:-len(suffix)]
  274. return input_string
  275. pipeline = None
  276. PIPELINE_ARGS = None
  277. model = None
  278. def load_model():
  279. global pipeline, PIPELINE_ARGS, model
  280. import os
  281. os.environ['RWKV_JIT_ON'] = '1'
  282. if (os.environ.get('RWKV_CUDA_ON') == '' or os.environ.get('RWKV_CUDA_ON') == None):
  283. os.environ["RWKV_CUDA_ON"] = '0'
  284. # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
  285. from rwkv.model import RWKV # pip install rwkv
  286. model = RWKV(model=settings.llm.path, strategy=settings.llm.strategy)
  287. # if settings.rwkv_lora_path == '':
  288. # else:
  289. # with torch.no_grad():
  290. from rwkv.utils import PIPELINE, PIPELINE_ARGS
  291. try:
  292. pipeline = PIPELINE(model, tokenizers_file)
  293. except:
  294. print(
  295. "不能使用world请更新rwkv:pip install -U rwkv -i https://mirrors.aliyun.com/pypi/simple")
  296. out, state = model.forward(pipeline.encode(f'''{user}{interface} hi
  297. {answer}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
  298. '''), None)
  299. states['default'] = State(state)
粤ICP备19079148号