generate_completions.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Generates completions from RWKV model based on a prompt.
  2. import argparse
  3. import os
  4. import time
  5. import sampling
  6. import rwkv_cpp_model
  7. import rwkv_cpp_shared_library
  8. from rwkv_tokenizer import get_tokenizer
  9. from typing import List
  10. # ======================================== Script settings ========================================
  11. prompt: str = """# rwkv.cpp
  12. This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).
  13. Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**."""
  14. # How many completions to generate.
  15. generation_count: int = 3
  16. # Token count per single completion.
  17. tokens_per_generation: int = 100
  18. # Sampling settings.
  19. temperature: float = 0.8
  20. top_p: float = 0.5
  21. # =================================================================================================
  22. parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
  23. parser.add_argument('model_path', help='Path to RWKV model in ggml format')
  24. parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
  25. args = parser.parse_args()
  26. assert prompt != '', 'Prompt must not be empty'
  27. tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
  28. prompt_tokens = tokenizer_encode(prompt)
  29. library = rwkv_cpp_shared_library.load_rwkv_shared_library()
  30. print(f'System info: {library.rwkv_get_system_info_string()}')
  31. print('Loading RWKV model')
  32. model = rwkv_cpp_model.RWKVModel(library, args.model_path)
  33. prompt_token_count = len(prompt_tokens)
  34. print(f'{prompt_token_count} tokens in prompt')
  35. init_logits, init_state = None, None
  36. for token in prompt_tokens:
  37. init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
  38. for GENERATION in range(generation_count):
  39. print(f'\n--- Generation {GENERATION} ---\n')
  40. print(prompt, end='[')
  41. start = time.time()
  42. logits, state = init_logits.clone(), init_state.clone()
  43. for i in range(tokens_per_generation):
  44. token = sampling.sample_logits(logits, temperature, top_p)
  45. print(tokenizer.decode([token]), end='', flush=True)
  46. logits, state = model.eval(token, state, state, logits)
  47. delay = time.time() - start
  48. print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000))
粤ICP备19079148号