rwkv_cpp_model.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import os
  2. import torch
  3. import multiprocessing
  4. from llms.rwkvcpp.rwkv_cpp_shared_library import RWKVSharedLibrary
  5. from typing import Tuple, Optional
  6. class RWKVModel:
  7. """
  8. PyTorch wrapper around rwkv.cpp model.
  9. """
  10. def __init__(
  11. self,
  12. shared_library: RWKVSharedLibrary,
  13. model_path: str,
  14. thread_count: int = max(1, multiprocessing.cpu_count() // 2),
  15. gpu_layers_count: int = 4,
  16. ):
  17. """
  18. Loads the model and prepares it for inference.
  19. In case of any error, this method will throw an exception.
  20. Parameters
  21. ----------
  22. shared_library : RWKVSharedLibrary
  23. rwkv.cpp shared library.
  24. model_path : str
  25. Path to RWKV model file in ggml format.
  26. thread_count : int
  27. Thread count to use. If not set, defaults to CPU count / 2.
  28. """
  29. assert os.path.isfile(model_path), f'{model_path} is not a file'
  30. assert thread_count > 0, 'Thread count must be positive'
  31. assert gpu_layers_count >= 0, 'GPU layers count must be >= 0'
  32. self._library = shared_library
  33. self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
  34. if gpu_layers_count > 0:
  35. self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count)
  36. self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
  37. self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
  38. self._valid = True
  39. def eval(
  40. self,
  41. token: int,
  42. state_in: Optional[torch.Tensor],
  43. state_out: Optional[torch.Tensor] = None,
  44. logits_out: Optional[torch.Tensor] = None
  45. ) -> Tuple[torch.Tensor, torch.Tensor]:
  46. """
  47. Evaluates the model for a single token.
  48. In case of any error, this method will throw an exception.
  49. Parameters
  50. ----------
  51. token : int
  52. Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
  53. state_in : Optional[torch.Tensor]
  54. State from previous call of this method. If this is a first pass, set it to None.
  55. state_out : Optional[torch.Tensor]
  56. Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
  57. logits_out : Optional[torch.Tensor]
  58. Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
  59. Returns
  60. -------
  61. logits, state
  62. Logits vector of shape (n_vocab); state for the next step.
  63. """
  64. assert self._valid, 'Model was freed'
  65. def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
  66. assert buf.dtype == torch.float32, f'{name} is not of type float32'
  67. assert buf.is_contiguous(), f'{name} is not contiguous'
  68. assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
  69. if state_in is not None:
  70. validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
  71. state_in_ptr = state_in.data_ptr()
  72. else:
  73. state_in_ptr = 0
  74. if state_out is not None:
  75. validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
  76. else:
  77. state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
  78. if logits_out is not None:
  79. validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
  80. else:
  81. logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
  82. self._library.rwkv_eval(
  83. self._ctx,
  84. token,
  85. state_in_ptr,
  86. state_out.data_ptr(),
  87. logits_out.data_ptr()
  88. )
  89. return logits_out, state_out
  90. def free(self):
  91. """
  92. Frees all allocated resources.
  93. In case of any error, this method will throw an exception.
  94. The object must not be used anymore after calling this method.
  95. """
  96. assert self._valid, 'Already freed'
  97. self._valid = False
  98. self._library.rwkv_free(self._ctx)
  99. def __del__(self):
  100. # Free the context on GC in case user forgot to call free() explicitly.
  101. if hasattr(self, '_valid') and self._valid:
  102. self.free()
粤ICP备19079148号