| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- from transformers import StoppingCriteriaList
- import torch
- from transformers import TextIteratorStreamer
- from plugins.common import settings
- from threading import Thread
- from transformers import AutoModelForCausalLM, AutoTokenizer
# Prompt-template pieces and the optional stop string, all read from the
# project's LLM settings.
user = settings.llm.user
answer = settings.llm.answer
interface = settings.llm.interface
stopping_text = settings.llm.stopping_text


def _ends_with_stopping_text(input_ids, scores):
    """Return True once the decoded first sequence ends with ``stopping_text``."""
    # `tokenizer` is a module global assigned by load_model(); it is only
    # looked up when generation actually invokes this callback.
    return tokenizer.decode(input_ids[0]).endswith(stopping_text)


# Without a configured stop string a plain empty list is passed to generate().
stopping_criteria = (
    StoppingCriteriaList([_ends_with_stopping_text]) if stopping_text else []
)

# A model path containing "gptq" (case-insensitive) selects the GPTQ code paths.
gptq = "gptq" in settings.llm.path.lower()
if gptq:
    print("gptq mode!")
class ThreadWithReturnValue(Thread):
    """Thread whose ``join()`` hands back the value returned by its target.

    Relies on the ``_target`` / ``_args`` / ``_kwargs`` attributes that
    ``threading.Thread`` stores for the callable passed to its constructor.
    """

    # Default result so join() is well-defined when the thread has no target,
    # the target raised before returning, or a join timeout expired first.
    _return = None

    def run(self):
        """Call the target (if any) and remember what it returns."""
        if self._target is not None:
            self._return = self._target(*self._args, **self._kwargs)

    def join(self, timeout=None):
        """Wait for the thread like ``Thread.join`` and return the target's result.

        Args:
            timeout: Optional number of seconds to wait, forwarded unchanged
                to ``Thread.join``.

        Returns:
            The target's return value, or ``None`` if no value was stored.
        """
        super().join(timeout)
        return self._return
def chat_init(history):
    """Flatten structured chat history into newline-joined prompt text.

    Args:
        history: Iterable of dicts with ``'role'`` and ``'content'`` keys,
            or ``None`` / empty when there is no previous conversation.

    Returns:
        One line per entry whose role is ``"user"`` or ``"AI"``, each prefixed
        with the configured ``user`` / ``answer`` label followed by
        ``interface``. Entries with any other role are skipped. Returns an
        empty string when there is nothing to format.
    """
    # chat_one() accepts a missing history, so tolerate None here as well
    # instead of failing while iterating it.
    if not history:
        return ""

    prefix_by_role = {
        "user": f"{user}{interface}",
        "AI": f"{answer}{interface}",
    }
    return '\n'.join(
        prefix_by_role[turn['role']] + turn['content']
        for turn in history
        if turn['role'] in prefix_by_role
    )
def chat_one(prompt, history, max_length, top_p, temperature, data):
    """Stream a completion for ``prompt`` from the loaded model.

    This is a generator. It yields, in order: a progress string containing
    the prompt length, the accumulated generated text after every non-empty
    streamed chunk (with a trailing ``"</s>"`` removed), and finally the fully
    decoded output with the prompt text and a trailing ``stopping_text``
    cut off.

    Args:
        prompt: The user message. A leading ``"raw!"`` marker means the rest
            is used verbatim instead of being wrapped in the user/answer
            template.
        history: Already formatted conversation text, or ``None``.
        max_length: Passed to ``generate`` as ``max_new_tokens``.
        top_p: Sampling parameter passed to ``generate``.
        temperature: Sampling parameter passed to ``generate``.
        data: Not used by this function.

    Yields:
        str: progress text, partial results, then the final result.
    """
    # NOTE(review): the source indentation was lost; this reconstruction
    # prepends the history in both raw and templated mode, and yields partial
    # text only for non-empty chunks -- confirm against the original file.
    if prompt.startswith("raw!"):
        print("[raw mode]", end="")
        # Strip only the leading marker; str.replace would also delete any
        # later "raw!" that is part of the user's text.
        prompt = prompt.removeprefix("raw!")
    else:
        prompt = f"{user}{interface}{prompt}\n{answer}{interface}"
    history = "" if history is None else history + '\n'
    prompt = history + prompt

    encoded = tokenizer(prompt, return_tensors='pt')
    if gptq:
        # The GPTQ path passes only the token ids, under the `inputs` keyword.
        model_inputs = {"inputs": encoded.input_ids.cuda()}
    else:
        # Otherwise every field of the tokenizer output becomes a keyword.
        model_inputs = dict(encoded.to('cuda:0'))
    yield str(len(prompt))+'字正在计算'

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
    generate_kwargs = dict(
        model_inputs,
        max_new_tokens=max_length,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.1,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )
    # generate() runs in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    thread = ThreadWithReturnValue(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        if new_text != '':
            generated_text += new_text
            yield generated_text.removesuffix("</s>")

    pred = thread.join()
    # The prompt is removed by character count, which relies on the decoded
    # sequence starting with exactly the prompt text.
    output = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[
        len(prompt):]
    if stopping_text:
        output = output.removesuffix(stopping_text)
    yield output
def load_model():
    """Load the tokenizer and model named by ``settings.llm`` into module globals.

    In GPTQ mode the model is loaded through ``AutoGPTQForCausalLM``.
    Otherwise it is loaded with ``AutoModelForCausalLM`` in float16 on
    ``cuda:0``, optionally wrapped in a LoRA ``PeftModel``, and quantized to
    8 bit when the model path contains ``"13b"``. When a LoRA path is
    configured, two bottle routes for loading and selecting adapters are
    registered as well.

    Side effects:
        Assigns the module globals ``model`` and ``tokenizer``.
    """
    # NOTE(review): the source indentation was lost. This reconstruction puts
    # the LoRA wrapping and the "13b" quantization in the non-GPTQ branch, and
    # keeps eval() plus the route registration at function level -- confirm
    # against the original file.
    global model, tokenizer
    model_path = settings.llm.path
    lora_setting = settings.llm.lora
    # LoRA is enabled unless the setting is None or the empty string.
    use_lora = lora_setting is not None and lora_setting != ''

    if gptq:
        from auto_gptq import AutoGPTQForCausalLM
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        model = AutoGPTQForCausalLM.from_quantized(
            model_path,
            model_basename=settings.llm.basename,
            use_safetensors=True,
            trust_remote_code=True,
            device="cuda:0",
            use_triton=False,
            quantize_config=None)
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, revision="1")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            device_map="cuda:0",
            revision="1")
        if use_lora:
            print('Lora模型地址', lora_setting)
            from peft import PeftModel
            # The adapter is registered under its own path as adapter name.
            model = PeftModel.from_pretrained(
                model, lora_setting, adapter_name=lora_setting)
        if "13b" in model_path.lower():
            model = model.quantize(8)
            model = model.cuda()
    model = model.eval()

    if use_lora:
        from bottle import route, request

        @route('/lora_load_adapter', method=("POST", "OPTIONS"))
        def load_adapter():
            """Load an adapter from JSON fields ``lora_path`` / ``adapter_name``."""
            # Any failure is reported to the client as the exception text.
            try:
                data = request.json
                lora_path = data.get("lora_path")
                adapter_name = data.get("adapter_name")
                model.load_adapter(lora_path, adapter_name=adapter_name)
                return "保存成功"
            except Exception as e:
                return str(e)

        @route('/lora_set_adapter', method=("POST", "OPTIONS"))
        def set_adapter():
            """Select the adapter named by the JSON field ``adapter_name``."""
            # Any failure is reported to the client as the exception text.
            try:
                data = request.json
                adapter_name = data.get("adapter_name")
                model.set_adapter(adapter_name)
                return "保存成功"
            except Exception as e:
                return str(e)
|