integrated world tokenizer for RWKV

Concedo 2023-06-13 20:06:19 +08:00
parent 9830871d0f
commit 871009dfab
5 changed files with 131202 additions and 7 deletions


@@ -409,8 +409,35 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
    }
    else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
    {
        //start loading the models first
        bool useWorldTokenizer = false;
        if (file_format == FileFormat::RWKV_1)
        {
            rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
        }
        else //rwkv_2
        {
            rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
            const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
            const size_t n_vocab = header.n_vocab;
            printf("\nDetected Vocab: %zu", n_vocab); //n_vocab is a size_t, so use %zu
            if (n_vocab > 60000)
            {
                printf("\nUsing WORLD TOKENIZER");
                useWorldTokenizer = true;
            }
        }

        //pick the vocab file matching the detected tokenizer
        std::string word;
        if (useWorldTokenizer)
        {
            read_rwkv_world_vocab();
        }
        else
        {
            read_rwkv_vocab();
        }

        int vocabsiz = rwkv_vocab.size();
        for (int i = 0; i < vocabsiz; i++)
        {
@@ -425,7 +452,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
        if (file_format == FileFormat::RWKV_1)
        {
            n_batch = 1;
            //setup buffers for rwkv state
            auto padding = 512u;
@@ -454,7 +480,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
        else
        {
            n_batch = 10; //use sequence mode to speedup
            //setup buffers for rwkv state
            auto padding = 512u;


@@ -18,11 +18,21 @@ static void replaceAll(std::string& str, const std::string& from, const std::str
    }
}

static std::string hexToUnicode(const std::string& hexString) {
    std::string unicodeString;
    for (size_t i = 0; i < hexString.length(); i += 2) {
        std::string byteString = hexString.substr(i, 2);
        unsigned int byteValue = std::stoi(byteString, nullptr, 16);
        unicodeString += static_cast<char>(byteValue);
    }
    return unicodeString;
}
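// Illustrative note (not part of the original change): hexToUnicode decodes two
// hex characters at a time into one raw byte, so a vocab line such as "e4bda0"
// becomes the three UTF-8 bytes 0xE4 0xBD 0xA0 ("你"). The world vocab file read
// further below is assumed to store one such hex-encoded token per line.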
void read_rwkv_vocab()
{
    std::string line;
    auto filepath = executable_path + "rwkv_vocab.embd";
    printf("\nReading vocab from %s", filepath.c_str());
    std::ifstream myfile(filepath);
    if (myfile.is_open())
    {
@@ -46,3 +56,32 @@ void read_rwkv_vocab()
        std::cout << "Unable to open RWKV vocab file";
    }
}

void read_rwkv_world_vocab() //it's stored in hexadecimal
{
    std::string line;
    std::string unicodeString;
    auto filepath = executable_path + "rwkv_world_vocab.embd";
    printf("\nReading world vocab from %s", filepath.c_str());
    std::ifstream myfile(filepath);
    if (myfile.is_open())
    {
        int slen = special.size();
        int idx = 0;
        //pad index 0 with a placeholder so real tokens start at index 1
        rwkv_vocab.push_back("<<UNUSED_TOKEN>>");
        while (myfile.good())
        {
            getline(myfile, line);
            unicodeString = hexToUnicode(line); //decode one hex-encoded token per line
            // printf("\n%d: %s",idx,unicodeString.c_str());
            rwkv_vocab.push_back(unicodeString);
            ++idx;
        }
        myfile.close();
    }
    else
    {
        std::cout << "Unable to open RWKV world vocab file";
    }
}


@@ -0,0 +1,73 @@
import json, os

special = []

def bytes_to_unicode():
    """
    Returns a mapping between utf-8 bytes and corresponding unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    global special
    special = bs  # note: bs is mutated below, so this ends up holding all 256 byte values
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
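# Illustrative sketch (not part of the original script): the mapping returned by
# bytes_to_unicode() is a bijection over all 256 byte values, so it can be
# inverted without collisions. For example:
#   m = bytes_to_unicode()
#   m[ord("A")]              # -> "A": printable bytes map to themselves
#   m[0]                     # -> chr(256): bytes outside the printable ranges
#                            #    are shifted up past U+00FF
#   inv = {v: k for k, v in m.items()}
#   len(inv)                 # -> 256, confirming the mapping is reversible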
# note: the two helpers below are not called by the conversion loop at the bottom
def get_code_points(string):
    code_points = []
    for char in string:
        if ord(char) <= 255:
            if ord(char) in special:
                code_points.append(char)
            else:
                # escape other single-byte chars as \uXXXX, offset by 255
                t = "\\u" + format(ord(char) + 255, "04x")
                code_points.append(t)
        else:
            code_points.append("\\u" + format(ord(char), "04x"))
    return "".join(code_points)

import unicodedata
def remove_nonprintable_characters(input_string):
    # strip any character in a Unicode "C" (control/format/surrogate) category
    cleaned_string = ''.join(
        c for c in input_string
        if unicodedata.category(c)[0] != 'C'
    )
    return cleaned_string
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])

# print the byte-decoder characters as a brace-wrapped, comma-separated list
tr = "{"
for i in sortedbd:
    tr += "\"" + i[0] + "\","
tr += "}"
print(tr)

with open((os.path.dirname(os.path.realpath(__file__)) + "/") + "rwkv_world_vocab.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()

s = ""
with open("rwkv_world_vocab.embd", "w", encoding="utf-8") as f2:
    nn = 0
    for l in lines:
        idx = int(l[:l.index(' ')])
        x = eval(l[l.index(' '):l.rindex(' ')])
        x = x.encode("utf-8") if isinstance(x, str) else x
        #dec = str(remove_nonprintable_characters(x.decode('ansi','ignore')))
        # print(str(x))
        s += x.hex() + "\n"
    f2.write(s)
print("OK")

File diff suppressed because it is too large.

rwkv_world_vocab.embd (new file, 65529 additions)

File diff suppressed because it is too large.