| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import os
- import tokenizers
- import pathlib
- ########################################################################################################
- # Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
- ########################################################################################################
class TRIE:
    """Node of a byte-level trie used for longest-prefix lookups.

    Every node holds a 256-entry child table indexed by byte value (``to``),
    the set of payloads stored for the key that ends at this node
    (``values``), the byte that leads into this node (``ch``) and a link to
    its parent node (``front``). The root has ``ch`` and ``front`` set to
    ``None``.
    """

    __slots__ = ("ch", "to", "values", "front")
    to: list
    values: set

    def __init__(self, front=None, ch=None):
        """Create a node reached from ``front`` through the byte ``ch``."""
        self.ch = ch
        self.to = [None] * 256
        self.values = set()
        self.front = front

    def __repr__(self):
        # Walk the parent links up to the root to recover the byte path.
        path = []
        node = self
        while node is not None:
            if node.ch is not None:
                path.append(node.ch)
            node = node.front
        return "<TRIE %s %s>" % (path[::-1], self.values)

    def add(self, key: bytes, idx: int = 0, val=None):
        """Insert ``key[idx:]`` below this node and return the terminal node.

        ``val`` is stored in the terminal node's ``values`` set; when it is
        ``None`` the key itself is stored. Missing intermediate nodes are
        created on the way down.
        """
        node = self
        # ``!=`` (not ``<``) keeps the IndexError for an idx past the end.
        while idx != len(key):
            ch = key[idx]
            if node.to[ch] is None:
                node.to[ch] = TRIE(front=node, ch=ch)
            node = node.to[ch]
            idx += 1
        node.values.add(key if val is None else val)
        return node

    def find_longest(self, key: bytes, idx: int = 0):
        """Find the longest stored key that is a prefix of ``key[idx:]``.

        Returns a tuple ``(end, node, values)`` where ``end`` is the index
        just past the match, ``node`` is the matching trie node and
        ``values`` is that node's payload set.

        Raises:
            IndexError: if ``idx`` is outside ``key``.
            ValueError: if no stored key is a prefix of ``key[idx:]``.
        """
        start = idx
        node: TRIE = self
        ch: int = key[idx]
        best = None

        while node.to[ch] is not None:
            node = node.to[ch]
            idx += 1
            if node.values:
                # Remember the deepest node seen so far that ends a key.
                best = idx, node, node.values
            if idx == len(key):
                break
            ch = key[idx]

        if best is None:
            raise ValueError(
                "no stored key is a prefix of the input at offset %d" % start
            )
        return best
class TRIE_TOKENIZER():
    """Greedy longest-match tokenizer backed by a byte-level trie.

    The vocabulary is read from a text file with one token per line in the
    form ``<index> <python str/bytes literal> <byte length>``.
    """

    def __init__(self, file_name):
        """Load the vocabulary from ``file_name`` and build the lookup trie.

        Raises:
            ValueError: if a token is not a str/bytes literal or its encoded
                length does not match the length recorded on its line.
        """
        # Local import so this class does not depend on the file's import block.
        import ast

        self.idx2token = {}
        with open(file_name, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                first_space = line.index(' ')
                last_space = line.rindex(' ')
                idx = int(line[:first_space])
                # literal_eval only accepts literals, so a vocabulary line
                # cannot execute code the way eval() could. The field is
                # stripped because literal_eval rejects leading whitespace
                # on Python versions before 3.10.
                token = ast.literal_eval(line[first_space:last_space].strip())
                if isinstance(token, str):
                    token = token.encode("utf-8")
                if not isinstance(token, bytes):
                    raise ValueError(
                        f"token {idx} is not a str or bytes literal: {token!r}"
                    )
                if len(token) != int(line[last_space:]):
                    raise ValueError(
                        f"token {idx} has length {len(token)}, "
                        f"but the vocabulary file declares {int(line[last_space:])}"
                    )
                self.idx2token[idx] = token

        self.token2idx = {token: int(idx) for idx, token in self.idx2token.items()}

        self.root = TRIE()
        for token, idx in self.token2idx.items():
            self.root.add(token, val=(token, idx))

    def encodeBytes(self, src: bytes) -> list[int]:
        """Split ``src`` into token ids by repeated longest-prefix matching."""
        idx: int = 0
        tokens: list[int] = []
        while idx < len(src):
            start: int = idx
            idx, _, values = self.root.find_longest(src, idx)
            # Internal invariant: every match must consume at least one byte.
            assert idx != start
            _, token = next(iter(values))
            tokens.append(token)
        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of the given token ids."""
        return b''.join(self.idx2token[i] for i in tokens)

    def encode(self, src):
        """Encode the text ``src`` (as UTF-8) into a list of token ids."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to text; invalid UTF-8 becomes U+FFFD."""
        return self.decodeBytes(tokens).decode('utf-8', 'replace')

    def printTokens(self, tokens):
        """Print each token's repr immediately followed by its id."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                # Not valid UTF-8 on its own: show the raw bytes instead.
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()
def get_tokenizer(tokenizer="20B"):
    """Load a tokenizer by name and return it with an encode function.

    Args:
        tokenizer: ``"world"`` loads ``rwkv_vocab_v20230424.txt`` with
            ``TRIE_TOKENIZER``; ``"20B"`` loads ``20B_tokenizer.json`` with
            ``tokenizers.Tokenizer``. Both files are looked up next to this
            source file.

    Returns:
        A tuple ``(tokenizer_object, encode)`` where ``encode(prompt)``
        returns the token ids for ``prompt``.

    Raises:
        SystemExit: if ``tokenizer`` is not a known name (after printing a
            message).
    """
    base_dir = pathlib.Path(os.path.abspath(__file__)).parent

    if tokenizer == "world":
        print('Loading world tokenizer')
        world_tokenizer = TRIE_TOKENIZER(base_dir / 'rwkv_vocab_v20230424.txt')

        def encode_world(prompt):
            return world_tokenizer.encode(prompt)

        return world_tokenizer, encode_world

    if tokenizer == "20B":
        print('Loading 20B tokenizer')
        tokenizer_path: pathlib.Path = base_dir / '20B_tokenizer.json'
        hf_tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

        def encode_20b(prompt):
            return hf_tokenizer.encode(prompt).ids

        return hf_tokenizer, encode_20b

    # Report the name actually passed in, then terminate. SystemExit is raised
    # directly because quit() is a site helper that is not always available.
    print(f"Unknown tokenizer: {tokenizer}")
    raise SystemExit(1)
|