llm_internlm.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import torch
  2. from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
  3. from transformers import AutoModelForCausalLM, AutoTokenizer
  4. import warnings
  5. import copy
  6. from typing import List, Optional, Callable, Optional
  7. from dataclasses import dataclass, asdict
  8. import torch.nn as nn
  9. from plugins.common import settings
  10. def chat_init(history):
  11. tmp = []
  12. # print(history)
  13. for i, old_chat in enumerate(history):
  14. if old_chat['role'] == "user":
  15. tmp.append(user_prompt.replace("{user}", old_chat['content']))
  16. elif old_chat['role'] == "AI":
  17. tmp.append(robot_prompt.replace("{robot}", old_chat['content']))
  18. else:
  19. continue
  20. history = ''.join(tmp)
  21. return history
  22. def chat_one(prompt, history, max_length, top_p, temperature, data):
  23. # if prompt.startswith("raw!"):
  24. # print("[raw mode]", end="")
  25. # prompt = prompt.replace("raw!", "")
  26. # else:
  27. # prompt = f"{user}{interface}{prompt}\n{answer}{interface}"
  28. generation_config = GenerationConfig(
  29. max_length=max_length,
  30. top_p=top_p,
  31. temperature=temperature,
  32. repetition_penalty=1.05
  33. )
  34. prompt = history + cur_query_prompt.replace("{user}", prompt)
  35. for i in generate_interactive(prompt, (generation_config),additional_eos_token_id=103028):
  36. yield i
  37. def load_model():
  38. global model, tokenizer
  39. model = AutoModelForCausalLM.from_pretrained(
  40. settings.llm.path, trust_remote_code=True).to(torch.bfloat16).cuda()
  41. tokenizer = AutoTokenizer.from_pretrained(
  42. settings.llm.path, trust_remote_code=True)
  43. @ torch.inference_mode()
  44. def generate_interactive(
  45. prompt,
  46. generation_config: Optional[GenerationConfig] = None,
  47. logits_processor: Optional[LogitsProcessorList] = None,
  48. stopping_criteria: Optional[StoppingCriteriaList] = None,
  49. prefix_allowed_tokens_fn: Optional[Callable[[
  50. int, torch.Tensor], List[int]]] = None,
  51. additional_eos_token_id: Optional[int] = None,
  52. **kwargs,
  53. ):
  54. inputs = tokenizer([prompt], padding=True, return_tensors="pt")
  55. input_length = len(inputs["input_ids"][0])
  56. for k, v in inputs.items():
  57. inputs[k] = v.cuda()
  58. input_ids = inputs["input_ids"]
  59. batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
  60. if generation_config is None:
  61. generation_config = model.generation_config
  62. generation_config = copy.deepcopy(generation_config)
  63. eos_token_id=[additional_eos_token_id]
  64. # 2. Set generation parameters if not already defined
  65. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  66. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
  67. logits_processor = model._get_logits_processor(
  68. generation_config=generation_config,
  69. input_ids_seq_length=input_ids_seq_length,
  70. encoder_input_ids=input_ids,
  71. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  72. logits_processor=logits_processor,
  73. )
  74. stopping_criteria = model._get_stopping_criteria(
  75. generation_config=generation_config, stopping_criteria=stopping_criteria
  76. )
  77. logits_warper = model._get_logits_warper(generation_config)
  78. unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
  79. scores = None
  80. model_kwargs = generation_config.update(**kwargs)
  81. while True:
  82. model_inputs = model.prepare_inputs_for_generation(
  83. input_ids, **model_kwargs)
  84. # forward pass to get next token
  85. outputs = model(
  86. **model_inputs,
  87. return_dict=True,
  88. output_attentions=False,
  89. output_hidden_states=False,
  90. )
  91. next_token_logits = outputs.logits[:, -1, :]
  92. # pre-process distribution
  93. next_token_scores = logits_processor(input_ids, next_token_logits)
  94. next_token_scores = logits_warper(input_ids, next_token_scores)
  95. # sample
  96. probs = nn.functional.softmax(next_token_scores, dim=-1)
  97. if generation_config.do_sample:
  98. next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
  99. else:
  100. next_tokens = torch.argmax(probs, dim=-1)
  101. # update generated ids, model inputs, and length for next step
  102. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  103. model_kwargs = model._update_model_kwargs_for_generation(
  104. outputs, model_kwargs, is_encoder_decoder=False
  105. )
  106. unfinished_sequences = unfinished_sequences.mul(
  107. (min(next_tokens != i for i in eos_token_id)).long())
  108. output_token_ids = input_ids[0].cpu().tolist()
  109. output_token_ids = output_token_ids[input_length:]
  110. for each_eos_token_id in eos_token_id:
  111. if output_token_ids[-1] == each_eos_token_id:
  112. output_token_ids = output_token_ids[:-1]
  113. response = tokenizer.decode(output_token_ids)
  114. yield response
  115. # stop when each sentence is finished, or if we exceed the maximum length
  116. if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
  117. break
  118. user_prompt = "<|User|>:{user}<eoh>\n"
  119. robot_prompt = "<|Bot|>:{robot}<eoa>\n"
  120. cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
粤ICP备19079148号