integrated world tokenizer for RWKV
This commit is contained in:
parent
9830871d0f
commit
871009dfab
5 changed files with 131202 additions and 7 deletions
|
@ -409,8 +409,35 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
//start loading the models first
|
||||
bool useWorldTokenizer = false;
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
}
|
||||
else //rwkv_2
|
||||
{
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
|
||||
const size_t n_vocab = header.n_vocab;
|
||||
printf("\nDetected Vocab: %d",n_vocab);
|
||||
if(n_vocab>60000)
|
||||
{
|
||||
printf("\nUsing WORLD TOKENIZER");
|
||||
useWorldTokenizer = true;
|
||||
}
|
||||
}
|
||||
|
||||
std::string word;
|
||||
read_rwkv_vocab();
|
||||
if(useWorldTokenizer)
|
||||
{
|
||||
read_rwkv_world_vocab();
|
||||
}
|
||||
else
|
||||
{
|
||||
read_rwkv_vocab();
|
||||
}
|
||||
|
||||
int vocabsiz = rwkv_vocab.size();
|
||||
for (int i = 0; i < vocabsiz; i++)
|
||||
{
|
||||
|
@ -425,7 +452,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
n_batch = 1;
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
auto padding = 512u;
|
||||
|
@ -454,7 +480,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
else
|
||||
{
|
||||
n_batch = 10; //use sequence mode to speedup
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
auto padding = 512u;
|
||||
|
|
|
@ -18,11 +18,21 @@ static void replaceAll(std::string& str, const std::string& from, const std::str
|
|||
}
|
||||
}
|
||||
|
||||
// Decodes a string of hexadecimal digit pairs into the raw byte string they
// represent (e.g. "48656c6c6f" -> "Hello"). Used to load vocab entries that
// are stored hex-encoded on disk.
// Unlike the std::stoi-based version, this never throws: decoding stops at
// the first non-hex character, and a trailing unpaired nibble is ignored,
// so one malformed vocab line cannot abort model loading.
static std::string hexToUnicode(const std::string& hexString) {
    // Value of a single hex digit, or -1 if the character is not hex.
    auto nibble = [](char c) -> int {
        if (c >= '0' && c <= '9') return c - '0';
        if (c >= 'a' && c <= 'f') return c - 'a' + 10;
        if (c >= 'A' && c <= 'F') return c - 'A' + 10;
        return -1;
    };
    std::string unicodeString;
    unicodeString.reserve(hexString.length() / 2);
    for (size_t i = 0; i + 1 < hexString.length(); i += 2) {
        const int hi = nibble(hexString[i]);
        const int lo = nibble(hexString[i + 1]);
        if (hi < 0 || lo < 0) {
            break; // malformed input: stop decoding rather than throw mid-load
        }
        unicodeString += static_cast<char>((hi << 4) | lo);
    }
    return unicodeString;
}
|
||||
|
||||
void read_rwkv_vocab()
|
||||
{
|
||||
std::string line;
|
||||
auto filepath = executable_path+ "rwkv_vocab.embd";
|
||||
printf("Reading vocab from %s",filepath.c_str());
|
||||
printf("\nReading vocab from %s",filepath.c_str());
|
||||
std::ifstream myfile(filepath);
|
||||
if (myfile.is_open())
|
||||
{
|
||||
|
@ -46,3 +56,32 @@ void read_rwkv_vocab()
|
|||
std::cout << "Unable to open RWKV vocab file";
|
||||
}
|
||||
}
|
||||
|
||||
void read_rwkv_world_vocab() //its in hexadecimal
|
||||
{
|
||||
std::string line;
|
||||
std::string unicodeString;
|
||||
auto filepath = executable_path+ "rwkv_world_vocab.embd";
|
||||
printf("\nReading world vocab from %s",filepath.c_str());
|
||||
std::ifstream myfile(filepath);
|
||||
if (myfile.is_open())
|
||||
{
|
||||
int slen = special.size();
|
||||
int idx = 0;
|
||||
rwkv_vocab.push_back("<<UNUSED_TOKEN>>");
|
||||
while (myfile.good())
|
||||
{
|
||||
getline(myfile, line);
|
||||
unicodeString = hexToUnicode(line);
|
||||
// printf("\n%d: %s",idx,unicodeString.c_str());
|
||||
rwkv_vocab.push_back(unicodeString);
|
||||
++idx;
|
||||
}
|
||||
myfile.close();
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
std::cout << "Unable to open RWKV world vocab file";
|
||||
}
|
||||
}
|
73
otherarch/tools/rwkv_prepare_vocab_world.py
Normal file
73
otherarch/tools/rwkv_prepare_vocab_world.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import json,os
|
||||
|
||||
special = []
|
||||
|
||||
def bytes_to_unicode():
    """
    Return a dict mapping every utf-8 byte value (0..255) to a unicode string,
    GPT-2 BPE style: printable/latin-1 bytes map to themselves, and the
    remaining bytes are shifted up past 255 so no token string ever contains
    whitespace/control characters the BPE code would choke on.

    Side effect: records the directly-printable byte values in the
    module-level `special` list, which get_code_points() consults.
    """
    global special
    # Byte values safe to keep as-is: '!'..'~', '\u00a1'..'\u00ac', '\u00ae'..'\u00ff'.
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    # Snapshot BEFORE bs is extended below. The original `special = bs`
    # aliased the list, so the appends in the loop also grew `special` to all
    # 256 byte values and the "printable" distinction was silently lost.
    special = bs[:]
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(c) for c in cs]
    return dict(zip(bs, cs))
|
||||
|
||||
def get_code_points(string):
    """
    Render `string` with every character outside the directly-printable byte
    set (module-level `special`, populated by bytes_to_unicode()) replaced by
    its "\\uXXXX" escape; printable characters pass through literally.
    """
    code_points = []
    for char in string:
        cp = ord(char)
        if cp <= 255 and cp in special:
            # Directly printable latin byte: keep the literal character.
            code_points.append(char)
        else:
            # The original non-special branch crashed: `ord(char+255)` is a
            # TypeError (str + int) and `t.decode(...)` is a Python-2 leftover
            # (str has no .decode). Both branches clearly intended the same
            # \uXXXX escape of the character's code point.
            code_points.append("\\u" + format(cp, "04x"))
    return "".join(code_points)
|
||||
|
||||
import unicodedata
|
||||
|
||||
def remove_nonprintable_characters(input_string):
    """Return input_string with every character whose Unicode general
    category is 'C*' (control, format, surrogate, unassigned) removed."""
    kept = [ch for ch in input_string
            if not unicodedata.category(ch).startswith('C')]
    return ''.join(kept)
|
||||
|
||||
# Emit the printable-byte table as a C++ string-array initializer (one entry
# per byte value, sorted by byte) for pasting into the C++ tokenizer source.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
# Sort by the original byte value so table index == byte.
sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])
tr = "{"
for ch, _ in sortedbd:
    # '"' (0x22) and '\' (0x5C) are themselves in the printable range; left
    # unescaped (as the original code did) the generated C++ initializer
    # contains `"""` / `"\"` and does not compile.
    if ch in ('"', '\\'):
        ch = '\\' + ch
    tr += "\"" + ch + "\","
tr += "}"
print(tr)
|
||||
|
||||
# Convert the RWKV world tokenizer vocab file (one "<id> <python-literal> <len>"
# entry per line) into rwkv_world_vocab.embd: one hex-encoded token per line,
# kept in file order so line position corresponds to token id.
vocab_path = (os.path.dirname(os.path.realpath(__file__)) + "/") + "rwkv_world_vocab.txt"
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_lines = f.readlines()  # was named `list`, shadowing the builtin
    out = ""
    with open("rwkv_world_vocab.embd", "w", encoding="utf-8") as f2:
        for l in vocab_lines:
            # NOTE(review): eval() of file content — only acceptable because
            # the vocab file is a trusted, locally generated artifact; do not
            # point this script at untrusted input.
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            out += x.hex() + "\n"
        # Write the accumulated buffer once, after the loop.
        f2.write(out)

print("OK")
|
65529
otherarch/tools/rwkv_world_vocab.txt
Normal file
65529
otherarch/tools/rwkv_world_vocab.txt
Normal file
File diff suppressed because it is too large
Load diff
65529
rwkv_world_vocab.embd
Normal file
65529
rwkv_world_vocab.embd
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue