integrated world tokenizer for RWKV
This commit is contained in:
parent
9830871d0f
commit
871009dfab
5 changed files with 131202 additions and 7 deletions
|
@ -409,8 +409,35 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
}
|
}
|
||||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||||
{
|
{
|
||||||
|
//start loading the models first
|
||||||
|
bool useWorldTokenizer = false;
|
||||||
|
if (file_format == FileFormat::RWKV_1)
|
||||||
|
{
|
||||||
|
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||||
|
}
|
||||||
|
else //rwkv_2
|
||||||
|
{
|
||||||
|
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||||
|
const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
|
||||||
|
const size_t n_vocab = header.n_vocab;
|
||||||
|
printf("\nDetected Vocab: %d",n_vocab);
|
||||||
|
if(n_vocab>60000)
|
||||||
|
{
|
||||||
|
printf("\nUsing WORLD TOKENIZER");
|
||||||
|
useWorldTokenizer = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string word;
|
std::string word;
|
||||||
read_rwkv_vocab();
|
if(useWorldTokenizer)
|
||||||
|
{
|
||||||
|
read_rwkv_world_vocab();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
read_rwkv_vocab();
|
||||||
|
}
|
||||||
|
|
||||||
int vocabsiz = rwkv_vocab.size();
|
int vocabsiz = rwkv_vocab.size();
|
||||||
for (int i = 0; i < vocabsiz; i++)
|
for (int i = 0; i < vocabsiz; i++)
|
||||||
{
|
{
|
||||||
|
@ -425,7 +452,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
if (file_format == FileFormat::RWKV_1)
|
if (file_format == FileFormat::RWKV_1)
|
||||||
{
|
{
|
||||||
n_batch = 1;
|
n_batch = 1;
|
||||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
|
||||||
|
|
||||||
//setup buffers for rwkv state
|
//setup buffers for rwkv state
|
||||||
auto padding = 512u;
|
auto padding = 512u;
|
||||||
|
@ -454,7 +480,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
n_batch = 10; //use sequence mode to speedup
|
n_batch = 10; //use sequence mode to speedup
|
||||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
|
||||||
|
|
||||||
//setup buffers for rwkv state
|
//setup buffers for rwkv state
|
||||||
auto padding = 512u;
|
auto padding = 512u;
|
||||||
|
|
|
@ -18,25 +18,35 @@ static void replaceAll(std::string& str, const std::string& from, const std::str
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decodes a string of hex digit pairs (e.g. "414243") into the raw byte
// sequence they encode, returned as a std::string.
// Note: despite the name, no unicode-aware transformation happens here —
// each pair of hex digits simply becomes one byte. std::stoi throws
// std::invalid_argument on non-hex input (same as the original).
static std::string hexToUnicode(const std::string& hexString) {
    std::string decoded;
    const size_t len = hexString.length();
    size_t pos = 0;
    while (pos < len) {
        const unsigned int value =
            static_cast<unsigned int>(std::stoi(hexString.substr(pos, 2), nullptr, 16));
        decoded.push_back(static_cast<char>(value));
        pos += 2;
    }
    return decoded;
}
|
||||||
|
|
||||||
void read_rwkv_vocab()
|
void read_rwkv_vocab()
|
||||||
{
|
{
|
||||||
std::string line;
|
std::string line;
|
||||||
auto filepath = executable_path+ "rwkv_vocab.embd";
|
auto filepath = executable_path+ "rwkv_vocab.embd";
|
||||||
printf("Reading vocab from %s",filepath.c_str());
|
printf("\nReading vocab from %s",filepath.c_str());
|
||||||
std::ifstream myfile(filepath);
|
std::ifstream myfile(filepath);
|
||||||
if (myfile.is_open())
|
if (myfile.is_open())
|
||||||
{
|
{
|
||||||
int slen = special.size();
|
int slen = special.size();
|
||||||
while (myfile.good())
|
while (myfile.good())
|
||||||
{
|
{
|
||||||
getline(myfile, line);
|
getline(myfile, line);
|
||||||
for(int i=0;i<slen;++i)
|
for(int i=0;i<slen;++i)
|
||||||
{
|
{
|
||||||
std::string swapped = "";
|
std::string swapped = "";
|
||||||
swapped.push_back((char)i);
|
swapped.push_back((char)i);
|
||||||
replaceAll(line,special[i],swapped);
|
replaceAll(line,special[i],swapped);
|
||||||
}
|
}
|
||||||
rwkv_vocab.push_back(line);
|
rwkv_vocab.push_back(line);
|
||||||
}
|
}
|
||||||
myfile.close();
|
myfile.close();
|
||||||
}
|
}
|
||||||
|
@ -45,4 +55,33 @@ void read_rwkv_vocab()
|
||||||
{
|
{
|
||||||
std::cout << "Unable to open RWKV vocab file";
|
std::cout << "Unable to open RWKV vocab file";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_rwkv_world_vocab() //its in hexadecimal
|
||||||
|
{
|
||||||
|
std::string line;
|
||||||
|
std::string unicodeString;
|
||||||
|
auto filepath = executable_path+ "rwkv_world_vocab.embd";
|
||||||
|
printf("\nReading world vocab from %s",filepath.c_str());
|
||||||
|
std::ifstream myfile(filepath);
|
||||||
|
if (myfile.is_open())
|
||||||
|
{
|
||||||
|
int slen = special.size();
|
||||||
|
int idx = 0;
|
||||||
|
rwkv_vocab.push_back("<<UNUSED_TOKEN>>");
|
||||||
|
while (myfile.good())
|
||||||
|
{
|
||||||
|
getline(myfile, line);
|
||||||
|
unicodeString = hexToUnicode(line);
|
||||||
|
// printf("\n%d: %s",idx,unicodeString.c_str());
|
||||||
|
rwkv_vocab.push_back(unicodeString);
|
||||||
|
++idx;
|
||||||
|
}
|
||||||
|
myfile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::cout << "Unable to open RWKV world vocab file";
|
||||||
|
}
|
||||||
}
|
}
|
73
otherarch/tools/rwkv_prepare_vocab_world.py
Normal file
73
otherarch/tools/rwkv_prepare_vocab_world.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import json,os
|
||||||
|
|
||||||
|
special = []
|
||||||
|
|
||||||
|
def bytes_to_unicode():
    """Build the GPT-2 style byte -> printable-unicode-character table.

    Returns a dict mapping every byte value 0..255 to a single unicode
    character: printable latin-1 bytes map to themselves, while the
    remaining bytes are remapped to code points 256 and upward so that no
    entry is a whitespace/control character the BPE code would choke on.

    Side effect: rebinds the module-level ``special`` to the *same* list
    object as the local byte list — so after the fill loop below,
    ``special`` also contains all 256 byte values (original behavior,
    deliberately preserved).
    """
    global special
    printable = [*range(ord("!"), ord("~") + 1),
                 *range(ord("¡"), ord("¬") + 1),
                 *range(ord("®"), ord("ÿ") + 1)]
    special = printable  # NB: alias, not a copy — mutated by the loop below
    codepoints = printable.copy()
    next_cp = 0x100
    for byte in range(0x100):
        if byte not in printable:
            printable.append(byte)
            codepoints.append(next_cp)
            next_cp += 1
    return dict(zip(printable, (chr(cp) for cp in codepoints)))
|
||||||
|
|
||||||
|
def get_code_points(string):
    """Render ``string`` as a mix of literal characters and ``\\uXXXX`` escapes.

    Characters whose code point is <= 255 and present in the module-level
    ``special`` list are kept verbatim; every other character is emitted
    as a four-hex-digit ``\\u`` escape.

    NOTE(review): because bytes_to_unicode() leaves ``special`` aliased to
    the full 0..255 byte list, the escape branch for <= 255 characters is
    unreachable in this script. The original code in that branch crashed
    on Python 3 (``ord(char+255)`` is a TypeError, and ``str`` has no
    ``.decode``); it now escapes the character the same way as the > 255
    branch — confirm against the intended output format if this function
    is ever called with a pruned ``special`` list.
    """
    code_points = []
    for char in string:
        cp = ord(char)
        if cp <= 255:
            if cp in special:
                code_points.append(char)
            else:
                # Fixed: was ("\\u" + format(ord(char+255), "04x")) followed
                # by str.decode — both Python 3 errors in dead code.
                code_points.append("\\u" + format(cp, "04x"))
        else:
            code_points.append("\\u" + format(cp, "04x"))
    return "".join(code_points)
|
||||||
|
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
def remove_nonprintable_characters(input_string):
    """Return ``input_string`` with every "C*"-category character removed.

    Unicode general categories starting with 'C' cover control, format,
    surrogate, private-use and unassigned code points.
    """
    kept = [ch for ch in input_string
            if not unicodedata.category(ch).startswith('C')]
    return ''.join(kept)
|
||||||
|
|
||||||
|
# --- Script body: print the byte<->unicode table as a C brace initializer,
# --- then convert rwkv_world_vocab.txt ("<id> <python literal> <len>" per
# --- line) into rwkv_world_vocab.embd (one hex-encoded token per line).
import ast

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

# Emit the decoder strings ordered by their byte value, C-initializer style.
sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])
tr = "{"
for entry in sortedbd:
    tr += "\"" + entry[0] + "\","
tr += "}"
print(tr)

with open((os.path.dirname(os.path.realpath(__file__)) + "/") + "rwkv_world_vocab.txt", "r", encoding="utf-8") as f:
    vocab_lines = f.readlines()  # renamed: previously shadowed the builtin `list`

s = ""
# Output is written relative to the current working directory (unchanged
# from the original — the .embd lands next to wherever the script is run).
with open("rwkv_world_vocab.embd", "w", encoding="utf-8") as f2:
    for l in vocab_lines:
        idx = int(l[:l.index(' ')])  # leading token id; parsed as a format sanity check
        # The middle field is a Python str/bytes literal. ast.literal_eval
        # parses it without executing arbitrary code (the original used eval).
        x = ast.literal_eval(l[l.index(' '):l.rindex(' ')])
        x = x.encode("utf-8") if isinstance(x, str) else x
        #dec = str(remove_nonprintable_characters(x.decode('ansi','ignore')))
        # print(str(x))
        s += x.hex() + "\n"
    f2.write(s)

print("OK")
|
65529
otherarch/tools/rwkv_world_vocab.txt
Normal file
65529
otherarch/tools/rwkv_world_vocab.txt
Normal file
File diff suppressed because it is too large
Load diff
65529
rwkv_world_vocab.embd
Normal file
65529
rwkv_world_vocab.embd
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue