rwkv is done
This commit is contained in:
parent
a76b15b581
commit
ea01771dd5
5 changed files with 62 additions and 12 deletions
4
Makefile
4
Makefile
|
@ -130,8 +130,8 @@ ifdef LLAMA_GPROF
|
||||||
CXXFLAGS += -pg
|
CXXFLAGS += -pg
|
||||||
endif
|
endif
|
||||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||||
CFLAGS += -mcpu=native
|
CFLAGS +=
|
||||||
CXXFLAGS += -mcpu=native
|
CXXFLAGS +=
|
||||||
endif
|
endif
|
||||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||||
# Raspberry Pi 1, 2, 3
|
# Raspberry Pi 1, 2, 3
|
||||||
|
|
|
@ -232,8 +232,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
|
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
|
||||||
print_tok_vec(embd_inp);
|
|
||||||
|
|
||||||
//truncate to front of the prompt if its too long
|
//truncate to front of the prompt if its too long
|
||||||
int32_t nctx = params.n_ctx;
|
int32_t nctx = params.n_ctx;
|
||||||
|
|
||||||
|
@ -258,7 +257,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
}
|
}
|
||||||
|
|
||||||
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
|
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
|
||||||
bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
|
bool approved_format = (file_format==FileFormat::GPT2_2 || file_format==FileFormat::GPTJ_3);
|
||||||
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
|
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
|
||||||
// bool blasmode = false;
|
// bool blasmode = false;
|
||||||
int original_batch = params.n_batch;
|
int original_batch = params.n_batch;
|
||||||
|
@ -304,6 +303,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
else if(file_format == FileFormat::RWKV_1)
|
else if(file_format == FileFormat::RWKV_1)
|
||||||
{
|
{
|
||||||
n_vocab = vocab.id_to_token.size(); //handled seperately
|
n_vocab = vocab.id_to_token.size(); //handled seperately
|
||||||
|
rwkv_context_v1->state_in = nullptr;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -333,9 +333,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
|
|
||||||
if(file_format==FileFormat::RWKV_1)
|
if(file_format==FileFormat::RWKV_1)
|
||||||
{
|
{
|
||||||
printf("\nsiz:%d val:%d\n",embd.size(),embd[0]);
|
|
||||||
evalres = rwkv_eval(rwkv_context_v1, embd[0], rwkv_context_v1->state_in, rwkv_context_v1->state_out, rwkv_context_v1->logits_out);
|
evalres = rwkv_eval(rwkv_context_v1, embd[0], rwkv_context_v1->state_in, rwkv_context_v1->state_out, rwkv_context_v1->logits_out);
|
||||||
memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
|
memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
|
||||||
|
rwkv_context_v1->state_in = rwkv_context_v1->state_out;
|
||||||
}
|
}
|
||||||
else if(file_format==FileFormat::GPT2_1)
|
else if(file_format==FileFormat::GPT2_1)
|
||||||
{
|
{
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -6,6 +6,17 @@
|
||||||
#include "expose.h"
|
#include "expose.h"
|
||||||
|
|
||||||
std::vector<std::string> rwkv_vocab;
|
std::vector<std::string> rwkv_vocab;
|
||||||
|
std::vector<std::string> special = {"Ā","ā","Ă","ă","Ą","ą","Ć","ć","Ĉ","ĉ","Ċ","ċ","Č","č","Ď","ď","Đ","đ","Ē","ē","Ĕ","ĕ","Ė","ė","Ę","ę","Ě","ě","Ĝ","ĝ","Ğ","ğ","Ġ","!","\"","#","$","%","&","\'","(",")","*","+",",","-",".","/","0","1","2","3","4","5","6","7","8","9",":",";","<","=",">","?","@","A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z","[","\\","]","^","_","`","a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z","{","|","}","~","ġ","Ģ","ģ","Ĥ","ĥ","Ħ","ħ","Ĩ","ĩ","Ī","ī","Ĭ","ĭ","Į","į","İ","ı","IJ","ij","Ĵ","ĵ","Ķ","ķ","ĸ","Ĺ","ĺ","Ļ","ļ","Ľ","ľ","Ŀ","ŀ","Ł","ł","¡","¢","£","¤","¥","¦","§","¨","©","ª","«","¬","Ń","®","¯","°","±","²","³","´","µ","¶","·","¸","¹","º","»","¼","½","¾","¿","À","Á","Â","Ã","Ä","Å","Æ","Ç","È","É","Ê","Ë","Ì","Í","Î","Ï","Ð","Ñ","Ò","Ó","Ô","Õ","Ö","×","Ø","Ù","Ú","Û","Ü","Ý","Þ","ß","à","á","â","ã","ä","å","æ","ç","è","é","ê","ë","ì","í","î","ï","ð","ñ","ò","ó","ô","õ","ö","÷","ø","ù","ú","û","ü","ý","þ","ÿ"};
|
||||||
|
|
||||||
|
static void replaceAll(std::string& str, const std::string& from, const std::string& to) {
|
||||||
|
if(from.empty())
|
||||||
|
return;
|
||||||
|
size_t start_pos = 0;
|
||||||
|
while((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
||||||
|
str.replace(start_pos, from.length(), to);
|
||||||
|
start_pos += to.length(); // In case 'to' contains 'from', like replacing 'x' with 'yx'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void read_rwkv_vocab()
|
void read_rwkv_vocab()
|
||||||
{
|
{
|
||||||
|
@ -15,10 +26,17 @@ void read_rwkv_vocab()
|
||||||
std::ifstream myfile(filepath);
|
std::ifstream myfile(filepath);
|
||||||
if (myfile.is_open())
|
if (myfile.is_open())
|
||||||
{
|
{
|
||||||
|
int slen = special.size();
|
||||||
while (myfile.good())
|
while (myfile.good())
|
||||||
{
|
{
|
||||||
getline(myfile, line);
|
getline(myfile, line);
|
||||||
rwkv_vocab.push_back(line);
|
for(int i=0;i<slen;++i)
|
||||||
|
{
|
||||||
|
std::string swapped = "";
|
||||||
|
swapped.push_back((char)i);
|
||||||
|
replaceAll(line,special[i],swapped);
|
||||||
|
}
|
||||||
|
rwkv_vocab.push_back(line);
|
||||||
}
|
}
|
||||||
myfile.close();
|
myfile.close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,45 @@
|
||||||
import json
|
import json,os
|
||||||
with open("rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
|
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8+n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
byte_encoder = bytes_to_unicode()
|
||||||
|
byte_decoder = {v:k for k, v in byte_encoder.items()}
|
||||||
|
sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])
|
||||||
|
tr = "{"
|
||||||
|
for i in sortedbd:
|
||||||
|
tr += "\""+i[0]+"\","
|
||||||
|
tr += "}"
|
||||||
|
print(tr)
|
||||||
|
|
||||||
|
with open((os.path.dirname(os.path.realpath(__file__))+"/") + "rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
|
||||||
encoder = json.load(f)
|
encoder = json.load(f)
|
||||||
s = ""
|
s = ""
|
||||||
with open("rwkv_vocab.embd", "w", encoding="utf-8") as f2:
|
with open("rwkv_vocab.embd", "w", encoding="utf-8") as f2:
|
||||||
for key in encoder:
|
for key in encoder:
|
||||||
|
#key = bytearray([byte_decoder[c] for c in key]).decode('utf-8','ignore')
|
||||||
# key = key.replace("\\","\\\\")
|
# key = key.replace("\\","\\\\")
|
||||||
# key = key.replace("\"","\\\"")
|
# key = key.replace("\"","\\\"")
|
||||||
# s += "\""+key+"\",\n"
|
# s += "\""+key+"\",\n"
|
||||||
s += key +"\n"
|
s += key +"\n"
|
||||||
f2.write(s)
|
f2.write(s)
|
||||||
|
|
||||||
print("OK")
|
print("OK")
|
Loading…
Add table
Add a link
Reference in a new issue