llm_aquila.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import os
  2. from typing import List, Any
  3. from enum import auto, Enum
  4. import dataclasses
  5. from flagai.data.tokenizer import Tokenizer
  6. from flagai.model.predictor.aquila import aquila_generate
  7. from flagai.model.predictor.predictor import Predictor
  8. from flagai.auto_model.auto_loader import AutoLoader
  9. from plugins.common import settings
  10. import torch
  def chat_init(history):
      """Create a fresh conversation pre-populated with earlier chat turns.

      Args:
          history: ``None`` or an iterable of dicts with ``'role'`` and
              ``'content'`` keys. ``"user"`` turns are tagged with
              ``roles[0]``; ``"AI"``/``"assistant"`` turns with ``roles[1]``;
              turns with any other role are skipped.

      Returns:
          A copy of ``default_conversation`` (its own ``messages`` list, so the
          module-level template is not mutated) with the history appended.
      """
      history_formatted = default_conversation.copy()
      if history is None:
          return history_formatted
      human_role = history_formatted.roles[0]
      assistant_role = history_formatted.roles[1]
      for old_chat in history:
          role = old_chat['role']
          if role == "user":
              history_formatted.append_message(human_role, old_chat['content'])
          elif role in ("AI", "assistant"):
              # Tag replies with the assistant role so get_prompt() renders
              # them as assistant turns rather than as extra human input.
              history_formatted.append_message(assistant_role, old_chat['content'])
      return history_formatted
  def chat_one(prompt, history_formatted, max_length, top_p, temperature, data):
      """Generate a reply to *prompt* given the conversation so far.

      This is a generator. It first yields a status string with the character
      count of the rendered prompt and its token count, then yields the value
      returned by ``aquila_generate``.

      Args:
          prompt: the new user message.
          history_formatted: Conversation extended in place; the user message
              and an empty (``None``) assistant slot are appended to it.
          max_length: forwarded to ``aquila_generate`` as ``max_gen_len``.
          top_p: forwarded to ``aquila_generate``.
          temperature: forwarded to ``aquila_generate`` (previously ignored).
          data: unused here.
      """
      history_formatted.append_message(history_formatted.roles[0], prompt)
      # A None message is rendered by get_prompt() as an open "role:" slot.
      history_formatted.append_message(history_formatted.roles[1], None)
      prompt = history_formatted.get_prompt()
      tokens = tokenizer.encode_plus(prompt, None, max_length=None)['input_ids']
      # Drop the first and last ids (presumably BOS/EOS added by encode_plus).
      tokens = tokens[1:-1]
      yield str(len(prompt)) + '字正在计算\n' + str(len(tokens)) + "tokens"
      with torch.no_grad():
          out = aquila_generate(
              tokenizer,
              model,
              [prompt],
              max_gen_len=max_length,
              temperature=temperature,
              top_p=top_p,
              prompts_tokens=[tokens],
          )
      yield out
  def load_model():
      """Load the Aquila model and tokenizer into the module globals.

      ``settings.llm.path`` is split into its parent directory (passed to
      AutoLoader as ``model_dir``) and its final component (``model_name``).
      The loaded model is put in eval mode, cast to half precision and moved
      to CUDA, in that order.
      """
      global model, tokenizer
      model_dir, model_name = os.path.split(settings.llm.path)
      loader = AutoLoader("lm", model_dir=model_dir, model_name=model_name, use_cache=True)
      model = loader.get_model()
      tokenizer = loader.get_tokenizer()
      # Inference setup, applied in this order: eval mode, fp16, GPU.
      for prepare in (model.eval, model.half, model.cuda):
          prepare()
  # Sample Chinese prompt ("Why is Beijing the capital of China?").
  # Appears unused by the functions in this file; pack_obj's parameter of the
  # same name shadows it.
  text = "北京为什么是中国的首都?"
  def pack_obj(text):
      """Wrap a single user utterance in a conversation dict.

      Returns a new dict with ``id`` ``'demo'``, a ``conversations`` list of
      a human turn holding *text* followed by an empty ``'gpt'`` turn
      (a placeholder for the reply), and an empty ``instruction``.
      """
      human_turn = {'from': 'human', 'value': text}
      # dummy bot turn to be filled by the model
      bot_placeholder = {'from': 'gpt', 'value': ''}
      return {
          'id': 'demo',
          'conversations': [human_turn, bot_placeholder],
          'instruction': '',
      }
  def delete_last_bot_end_singal(convo_obj):
      """Remove the trailing end signal from the last bot turn, in place.

      ``convo_obj['conversations']`` must be non-empty, hold an even number of
      turns, start with a ``'human'`` turn and end with a ``'gpt'`` turn
      (checked with ``assert``). If the final turn's value ends with the end
      signal (a newline), that one suffix is removed; otherwise the value is
      left untouched.

      Returns:
          None.
      """
      conversations = convo_obj['conversations']
      assert len(conversations) > 0 and len(conversations) % 2 == 0
      assert conversations[0]['from'] == 'human'
      last_bot = conversations[-1]
      assert last_bot['from'] == 'gpt'
      # from _add_speaker_and_signal
      END_SIGNAL = "\n"
      value = last_bot['value']
      # Only strip the signal when it is present; slicing unconditionally
      # would cut the final real character off a reply that lacks it.
      if value.endswith(END_SIGNAL):
          last_bot['value'] = value[:len(value) - len(END_SIGNAL)]
  def convo_tokenize(convo_obj, tokenizer):
      """Tokenize a conversation dict into one flat list of token ids.

      The result is built from three parts, in order:

      * the encoded ``chat_desc`` with its last id removed (presumably EOS);
      * the encoded ``instruction`` with its first and last ids removed
        (presumably BOS/EOS);
      * each turn's ``value``, likewise with its first and last ids removed.

      Speaker labels (``'from'``) are not encoded.

      Args:
          convo_obj: dict with ``'instruction'`` and ``'conversations'`` (a
              list of dicts with a ``'value'`` key). ``'chat_desc'`` is
              optional and defaults to ``""``, so objects built by
              ``pack_obj`` (which sets no description) are accepted.
          tokenizer: object whose ``encode_plus(text, None, max_length=None)``
              returns a mapping with an ``'input_ids'`` list.

      Returns:
          list of token ids.
      """
      def _ids(value):
          # The f-string stringifies non-str values before encoding.
          return tokenizer.encode_plus(f"{value}", None, max_length=None)['input_ids']

      # chat_desc: keep the leading id, drop the trailing one (presumably EOS)
      example = _ids(convo_obj.get('chat_desc', ""))[:-1]
      # instruction and every turn: drop leading and trailing ids (bos & eos)
      example += _ids(convo_obj['instruction'])[1:-1]
      for conversation in convo_obj['conversations']:
          example += _ids(conversation['value'])[1:-1]
      return example
  class SeparatorStyle(Enum):
      """Separator scheme used by Conversation.get_prompt.

      SINGLE puts ``sep`` after every completed turn; TWO alternates
      ``sep`` and ``sep2``.
      """
      SINGLE = 1
      TWO = 2
  @dataclasses.dataclass
  class Conversation:
      """Chat history together with the template used to render it as a prompt."""

      system: str  # preamble that starts every prompt
      instruction: str  # rendered as a roles[2] turn after the preamble; skipped when None/empty
      roles: List[str]  # speaker labels; roles[2] labels the instruction turn
      messages: List[List[str]]  # [role, message] pairs; a falsy message renders as an open "role:" slot
      offset: int  # number of leading messages skipped by to_gradio_chatbot()
      sep_style: SeparatorStyle = SeparatorStyle.SINGLE
      sep: str = "###"
      sep2: str = None  # second separator, only used by SeparatorStyle.TWO
      skip_next: bool = False  # not used by any method here; copy() resets it
      conv_id: Any = None  # carried through copy() and dict()

      def get_prompt(self):
          """Render the preamble, optional instruction and messages into one string.

          With SINGLE, ``sep`` follows every completed turn. With TWO, turns at
          even positions are followed by ``sep`` and odd ones by ``sep2``. A
          falsy message is rendered as ``"role:"`` with no separator.

          Raises:
              ValueError: if ``sep_style`` is not a known SeparatorStyle.
          """
          if self.sep_style == SeparatorStyle.SINGLE:
              turn_seps = (self.sep, self.sep)
          elif self.sep_style == SeparatorStyle.TWO:
              turn_seps = (self.sep, self.sep2)
          else:
              raise ValueError(f"Invalid style: {self.sep_style}")

          pieces = [self.system, self.sep]
          if self.instruction:
              pieces.extend((self.roles[2], ": ", self.instruction, self.sep))
          for idx, (speaker, text) in enumerate(self.messages):
              if text:
                  pieces.extend((speaker, ": ", text, turn_seps[idx % 2]))
              else:
                  pieces.extend((speaker, ":"))
          return "".join(pieces)

      def append_message(self, role, message):
          """Append a ``[role, message]`` pair to ``messages`` in place."""
          self.messages.append([role, message])

      def to_gradio_chatbot(self):
          """Pair messages from ``offset`` onward into ``[first, reply]`` rows.

          Even positions open a new row whose reply is ``None``; odd positions
          fill in the reply of the row just opened.
          """
          rows = []
          for pos, (_, msg) in enumerate(self.messages[self.offset:]):
              if pos % 2:
                  rows[-1][-1] = msg
              else:
                  rows.append([msg, None])
          return rows

      def copy(self):
          """Return a new Conversation with its own list of message pairs.

          Each pair is rebuilt as a fresh list; every other field is shared by
          reference, except ``skip_next``, which falls back to its default.
          """
          fresh_messages = [[role, msg] for role, msg in self.messages]
          return Conversation(
              conv_id=self.conv_id,
              sep2=self.sep2,
              sep=self.sep,
              sep_style=self.sep_style,
              offset=self.offset,
              messages=fresh_messages,
              roles=self.roles,
              instruction=self.instruction,
              system=self.system,
          )

      def dict(self):
          """Return the main fields as a plain dict (omits sep_style and skip_next)."""
          exported = ("system", "instruction", "roles", "messages",
                      "offset", "sep", "sep2", "conv_id")
          return {key: getattr(self, key) for key in exported}
  # Prompt templates. ``messages`` is an empty tuple, so append_message() on a
  # template itself raises; call ``.copy()`` to get a conversation with its
  # own mutable list of messages.
  conv_v1 = Conversation(
      system=(
          "A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions."
      ),
      instruction="",
      roles=("Human", "Assistant", "System"),
      messages=(),
      offset=0,
      sep_style=SeparatorStyle.SINGLE,
      sep="###",
  )

  # A separate instance with exactly the same field values as conv_v1.
  conv_v1_2 = dataclasses.replace(conv_v1)

  conv_bair_v1 = Conversation(
      system="BEGINNING OF CONVERSATION:",
      instruction="",
      roles=("USER", "GPT", "System"),
      messages=(),
      offset=0,
      sep_style=SeparatorStyle.TWO,
      sep=" ",
      sep2="</s>",
  )

  # Template copied by chat_init() for every new chat.
  default_conversation = conv_v1_2
粤ICP备19079148号