rwkv_tokenizer.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import os
  2. import tokenizers
  3. import pathlib
  4. ########################################################################################################
  5. # Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
  6. ########################################################################################################
  7. class TRIE:
  8. __slots__ = tuple("ch,to,values,front".split(","))
  9. to:list
  10. values:set
  11. def __init__(self, front=None, ch=None):
  12. self.ch = ch
  13. self.to = [None for ch in range(256)]
  14. self.values = set()
  15. self.front = front
  16. def __repr__(self):
  17. fr = self
  18. ret = []
  19. while(fr!=None):
  20. if(fr.ch!=None):
  21. ret.append(fr.ch)
  22. fr = fr.front
  23. return "<TRIE %s %s>"%(ret[::-1], self.values)
  24. def add(self, key:bytes, idx:int=0, val=None):
  25. if(idx == len(key)):
  26. if(val is None):
  27. val = key
  28. self.values.add(val)
  29. return self
  30. ch = key[idx]
  31. if(self.to[ch] is None):
  32. self.to[ch] = TRIE(front=self, ch=ch)
  33. return self.to[ch].add(key, idx=idx+1, val=val)
  34. def find_longest(self, key:bytes, idx:int=0):
  35. u:TRIE = self
  36. ch:int = key[idx]
  37. while(u.to[ch] is not None):
  38. u = u.to[ch]
  39. idx += 1
  40. if(u.values):
  41. ret = idx, u, u.values
  42. if(idx==len(key)):
  43. break
  44. ch = key[idx]
  45. return ret
  46. class TRIE_TOKENIZER():
  47. def __init__(self, file_name):
  48. self.idx2token = {}
  49. sorted = [] # must be already sorted
  50. with open(file_name, "r", encoding="utf-8") as f:
  51. lines = f.readlines()
  52. for l in lines:
  53. idx = int(l[:l.index(' ')])
  54. x = eval(l[l.index(' '):l.rindex(' ')])
  55. x = x.encode("utf-8") if isinstance(x, str) else x
  56. assert isinstance(x, bytes)
  57. assert len(x) == int(l[l.rindex(' '):])
  58. sorted += [x]
  59. self.idx2token[idx] = x
  60. self.token2idx = {}
  61. for k,v in self.idx2token.items():
  62. self.token2idx[v] = int(k)
  63. self.root = TRIE()
  64. for t, i in self.token2idx.items():
  65. _ = self.root.add(t, val=(t, i))
  66. def encodeBytes(self, src:bytes) -> list[int]:
  67. idx:int = 0
  68. tokens:list[int] = []
  69. while (idx < len(src)):
  70. _idx:int = idx
  71. idx, _, values = self.root.find_longest(src, idx)
  72. assert(idx != _idx)
  73. _, token = next(iter(values))
  74. tokens.append(token)
  75. return tokens
  76. def decodeBytes(self, tokens):
  77. return b''.join(map(lambda i: self.idx2token[i], tokens))
  78. def encode(self, src):
  79. return self.encodeBytes(src.encode("utf-8"))
  80. def decode(self, tokens):
  81. return self.decodeBytes(tokens).decode('utf-8','replace')
  82. def printTokens(self, tokens):
  83. for i in tokens:
  84. s = self.idx2token[i]
  85. try:
  86. s = s.decode('utf-8')
  87. except:
  88. pass
  89. print(f'{repr(s)}{i}', end=' ')
  90. print()
  91. def get_tokenizer(tokenizer="20B"):
  92. if tokenizer == "world":
  93. print('Loading world tokenizer')
  94. tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / 'rwkv_vocab_v20230424.txt'
  95. tokenizer = TRIE_TOKENIZER(tokenizer_path)
  96. tokenizer_encode = lambda prompt: tokenizer.encode(prompt)
  97. elif tokenizer == "20B":
  98. print('Loading 20B tokenizer')
  99. tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
  100. tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
  101. tokenizer_encode = lambda prompt: tokenizer.encode(prompt).ids
  102. else:
  103. print(f"Unknown tokenizer: {args.tokenizer}")
  104. quit()
  105. return tokenizer, tokenizer_encode
粤ICP备19079148号