measure_pexplexity.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Measures perplexity and per-token latency of an RWKV model on a given text file.
  2. # Perplexity is defined here as exp() of average cross-entropy loss.
  3. # Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
  4. import os
  5. import time
  6. import argparse
  7. import torch
  8. import rwkv_cpp_model
  9. import rwkv_cpp_shared_library
  10. from rwkv_tokenizer import get_tokenizer
  11. def parse_args():
  12. parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
  13. parser.add_argument('model_path', help='Path to model checkpoint file', type=str)
  14. parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
  15. parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
  16. parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
  17. parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
  18. return parser.parse_args()
  19. args = parse_args()
  20. print('Loading text')
  21. text: str = open(args.text_path, encoding='utf-8').read()
  22. tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
  23. tokens = tokenizer_encode(text)
  24. token_count: int = len(tokens)
  25. print(f'{token_count} tokens in the text')
  26. token_limit: int = args.token_limit
  27. assert token_limit == -1 or token_limit > 0, 'Invalid token_limit'
  28. if token_limit != -1 and token_count > token_limit:
  29. tokens = tokens[0:token_limit]
  30. token_count = token_limit
  31. print(f'Text was limited to {token_limit} tokens')
  32. assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
  33. # ---
  34. def format_loss(loss: torch.Tensor) -> str:
  35. return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]
  36. def format_loss_with_perplexity(loss: torch.Tensor) -> str:
  37. return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
  38. # ---
  39. model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
  40. rwkv_cpp_shared_library.load_rwkv_shared_library(),
  41. args.model_path
  42. )
  43. logits, state = None, None
  44. loss_sum: torch.Tensor = torch.tensor([0.0])
  45. loss_count: int = 0
  46. start: float = time.time()
  47. run_count: int = token_count - 1
  48. for i in range(run_count):
  49. token: int = tokens[i]
  50. target: int = tokens[i + 1]
  51. logits, state = model.eval(token, state, state, logits)
  52. if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
  53. losses = torch.tensor([
  54. torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
  55. ])
  56. loss_sum += losses
  57. loss_count += 1
  58. if run_count <= 5 or i % (run_count // 10) == 0:
  59. avg_loss_so_far = loss_sum / loss_count
  60. duration: float = time.time() - start
  61. duration_per_token: float = duration / (i + 1)
  62. runs_remaining: int = run_count - i - 1
  63. duration_remaining: int = int(runs_remaining * duration_per_token)
  64. print(f'Token #{i}/{token_count}, '
  65. f'{int(100.0 * i / token_count)}%, '
  66. f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
  67. if loss_count > 0:
  68. print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
  69. else:
  70. print()
  71. print()
  72. print(f'Model: {os.path.basename(args.model_path)}, '
  73. f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
  74. f'skipped {args.ignore_first_n_tokens} tokens, '
  75. f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}, '
  76. f'latency {int((time.time() - start) * 1000 / run_count)} ms per token')
粤ICP备19079148号