rwkv_cpp_shared_library.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import os
  2. import sys
  3. import ctypes
  4. import pathlib
  5. from typing import Optional
  6. QUANTIZED_FORMAT_NAMES = (
  7. 'Q4_0',
  8. 'Q4_1',
  9. 'Q5_0',
  10. 'Q5_1',
  11. 'Q8_0'
  12. )
  13. P_FLOAT = ctypes.POINTER(ctypes.c_float)
  14. class RWKVContext:
  15. def __init__(self, ptr: ctypes.pointer):
  16. self.ptr = ptr
  17. class RWKVSharedLibrary:
  18. """
  19. Python wrapper around rwkv.cpp shared library.
  20. """
  21. def __init__(self, shared_library_path: str):
  22. """
  23. Loads the shared library from specified file.
  24. In case of any error, this method will throw an exception.
  25. Parameters
  26. ----------
  27. shared_library_path : str
  28. Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
  29. """
  30. self.library = ctypes.cdll.LoadLibrary(shared_library_path)
  31. self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
  32. self.library.rwkv_init_from_file.restype = ctypes.c_void_p
  33. self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
  34. self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
  35. self.library.rwkv_eval.argtypes = [
  36. ctypes.c_void_p, # ctx
  37. ctypes.c_int32, # token
  38. P_FLOAT, # state_in
  39. P_FLOAT, # state_out
  40. P_FLOAT # logits_out
  41. ]
  42. self.library.rwkv_eval.restype = ctypes.c_bool
  43. self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
  44. self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
  45. self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
  46. self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
  47. self.library.rwkv_free.argtypes = [ctypes.c_void_p]
  48. self.library.rwkv_free.restype = None
  49. self.library.rwkv_free.argtypes = [ctypes.c_void_p]
  50. self.library.rwkv_free.restype = None
  51. self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
  52. self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
  53. self.library.rwkv_get_system_info_string.argtypes = []
  54. self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
  55. def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
  56. """
  57. Loads the model from a file and prepares it for inference.
  58. Throws an exception in case of any error. Error messages would be printed to stderr.
  59. Parameters
  60. ----------
  61. model_file_path : str
  62. Path to model file in ggml format.
  63. thread_count : int
  64. Count of threads to use, must be positive.
  65. gpu_layers_count : int
  66. Count of layers to load on gpu, must be positive only enabled with cuBLAS.
  67. """
  68. ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
  69. ctypes.c_uint32(thread_count))
  70. assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
  71. return RWKVContext(ptr)
  72. def rwkv_gpu_offload_layers(self, ctx: RWKVContext, gpu_layers_count: int) -> None:
  73. """
  74. Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
  75. If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
  76. Parameters
  77. ----------
  78. gpu_layers_count : int
  79. Count of layers to load onto gpu, must be >= 0, only enabled with cuBLAS.
  80. """
  81. assert self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(gpu_layers_count)), 'rwkv_gpu_offload_layers failed, check stderr'
  82. def rwkv_eval(
  83. self,
  84. ctx: RWKVContext,
  85. token: int,
  86. state_in_address: Optional[int],
  87. state_out_address: int,
  88. logits_out_address: int
  89. ) -> None:
  90. """
  91. Evaluates the model for a single token.
  92. Throws an exception in case of any error. Error messages would be printed to stderr.
  93. Parameters
  94. ----------
  95. ctx : RWKVContext
  96. RWKV context obtained from rwkv_init_from_file.
  97. token : int
  98. Next token index, in range 0 <= token < n_vocab.
  99. state_in_address : int
  100. Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
  101. state_out_address : int
  102. Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
  103. logits_out_address : int
  104. Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
  105. """
  106. assert self.library.rwkv_eval(
  107. ctx.ptr,
  108. ctypes.c_int32(token),
  109. ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
  110. ctypes.cast(state_out_address, P_FLOAT),
  111. ctypes.cast(logits_out_address, P_FLOAT)
  112. ), 'rwkv_eval failed, check stderr'
  113. def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
  114. """
  115. Returns count of FP32 elements in state buffer.
  116. Parameters
  117. ----------
  118. ctx : RWKVContext
  119. RWKV context obtained from rwkv_init_from_file.
  120. """
  121. return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
  122. def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
  123. """
  124. Returns count of FP32 elements in logits buffer.
  125. Parameters
  126. ----------
  127. ctx : RWKVContext
  128. RWKV context obtained from rwkv_init_from_file.
  129. """
  130. return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
  131. def rwkv_free(self, ctx: RWKVContext) -> None:
  132. """
  133. Frees all allocated memory and the context.
  134. Parameters
  135. ----------
  136. ctx : RWKVContext
  137. RWKV context obtained from rwkv_init_from_file.
  138. """
  139. self.library.rwkv_free(ctx.ptr)
  140. ctx.ptr = ctypes.cast(0, ctypes.c_void_p)
  141. def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
  142. """
  143. Quantizes FP32 or FP16 model to one of INT4 formats.
  144. Throws an exception in case of any error. Error messages would be printed to stderr.
  145. Parameters
  146. ----------
  147. model_file_path_in : str
  148. Path to model file in ggml format, must be either FP32 or FP16.
  149. model_file_path_out : str
  150. Quantized model will be written here.
  151. format_name : str
  152. One of QUANTIZED_FORMAT_NAMES.
  153. """
  154. assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
  155. assert self.library.rwkv_quantize_model_file(
  156. model_file_path_in.encode('utf-8'),
  157. model_file_path_out.encode('utf-8'),
  158. format_name.encode('utf-8')
  159. ), 'rwkv_quantize_model_file failed, check stderr'
  160. def rwkv_get_system_info_string(self) -> str:
  161. """
  162. Returns system information string.
  163. """
  164. return self.library.rwkv_get_system_info_string().decode('utf-8')
  165. def load_rwkv_shared_library() -> RWKVSharedLibrary:
  166. """
  167. Attempts to find rwkv.cpp shared library and load it.
  168. To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
  169. """
  170. file_name: str
  171. if 'win32' in sys.platform or 'cygwin' in sys.platform:
  172. file_name = 'rwkv.dll'
  173. elif 'darwin' in sys.platform:
  174. file_name = 'librwkv.dylib'
  175. else:
  176. file_name = 'librwkv.so'
  177. repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
  178. paths = [
  179. # If we are in "rwkv" directory
  180. f'../bin/Release/{file_name}',
  181. # If we are in repo root directory
  182. f'bin/Release/{file_name}',
  183. # If we compiled in build directory
  184. f'build/bin/Release/{file_name}',
  185. # If we compiled in build directory
  186. f'build/{file_name}',
  187. # Search relative to this file
  188. str(repo_root_dir / 'bin' / 'Release' / file_name),
  189. # Fallback
  190. str(repo_root_dir / 'rwkvcpp' / file_name),
  191. ]
  192. for path in paths:
  193. if os.path.isfile(path):
  194. return RWKVSharedLibrary(path)
  195. return RWKVSharedLibrary(paths[-1])
粤ICP备19079148号