convert_pytorch_to_ggml.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.
  5. import argparse
  6. import struct
  7. import torch
  8. from typing import Dict
  9. def parse_args():
  10. parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
  11. parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
  12. parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
  13. parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32')
  14. return parser.parse_args()
  15. def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
  16. n_layer = 0
  17. while f'blocks.{n_layer}.ln1.weight' in state_dict:
  18. n_layer += 1
  19. assert n_layer > 0
  20. return n_layer
  21. def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
  22. emb_weight: torch.Tensor = state_dict['emb.weight']
  23. n_layer = get_layer_count(state_dict)
  24. n_vocab = emb_weight.shape[0]
  25. n_embed = emb_weight.shape[1]
  26. with open(dest_path, 'wb') as out_file:
  27. out_file.write(struct.pack(
  28. # Disable padding with '='
  29. '=iiiiii',
  30. # Magic: 'ggmf' in hex
  31. 0x67676d66,
  32. 101,
  33. n_vocab,
  34. n_embed,
  35. n_layer,
  36. 1 if data_type == 'float16' else 0
  37. ))
  38. for k in state_dict.keys():
  39. tensor = state_dict[k].float()
  40. # Same processing as in "RWKV_in_150_lines.py"
  41. if '.time_' in k:
  42. # (1, 1, n_embed) -> (n_embed)
  43. tensor = tensor.squeeze()
  44. if '.time_decay' in k:
  45. tensor = -torch.exp(tensor)
  46. # Keep 1-dim vectors in fp32
  47. if data_type == 'float16' and len(tensor.shape) > 1:
  48. tensor = tensor.half()
  49. shape = tensor.shape
  50. print(f'Writing {k}, shape {shape}, type {tensor.dtype}')
  51. k_encoded: bytes = k.encode('utf-8')
  52. out_file.write(struct.pack(
  53. '=iii',
  54. len(shape),
  55. len(k_encoded),
  56. 1 if tensor.dtype == torch.float16 else 0
  57. ))
  58. # Dimension order is reversed here:
  59. # * PyTorch shape is (x rows, y columns)
  60. # * ggml shape is (y elements in a row, x elements in a column)
  61. # Both shapes represent the same tensor.
  62. for dim in reversed(tensor.shape):
  63. out_file.write(struct.pack('=i', dim))
  64. out_file.write(k_encoded)
  65. tensor.numpy().tofile(out_file)
  66. def main() -> None:
  67. args = parse_args()
  68. print(f'Reading {args.src_path}')
  69. state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location='cpu')
  70. write_state_dict(state_dict, args.dest_path, args.data_type)
  71. print('Done')
  72. if __name__ == "__main__":
  73. main()
粤ICP备19079148号