rwkv is done

Concedo 2023-04-18 20:55:01 +08:00
parent a76b15b581
commit ea01771dd5
5 changed files with 62 additions and 12 deletions


@@ -130,8 +130,8 @@ ifdef LLAMA_GPROF
 CXXFLAGS += -pg
 endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
-CFLAGS += -mcpu=native
-CXXFLAGS += -mcpu=native
+CFLAGS +=
+CXXFLAGS +=
 endif
 ifneq ($(filter armv6%,$(UNAME_M)),)
 # Raspberry Pi 1, 2, 3


@@ -232,8 +232,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 // tokenize the prompt
 std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
-print_tok_vec(embd_inp);
 //truncate to front of the prompt if its too long
 int32_t nctx = params.n_ctx;
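
The comment kept in this hunk ("truncate to front of the prompt if its too long") names the step that follows the removed debug print. As a minimal sketch only (not the commit's exact code), one plausible reading of that step, using just embd_inp and nctx from this hunk and assuming no extra tokens are reserved for generation:

    // sketch: drop tokens from the front so at most nctx tokens remain,
    // keeping the newest part of the prompt
    if ((int)embd_inp.size() > nctx)
    {
        int overflow = (int)embd_inp.size() - nctx; // assumed: no output reserve
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + overflow);
    }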
@@ -258,7 +257,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 }
 //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
-bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
+bool approved_format = (file_format==FileFormat::GPT2_2 || file_format==FileFormat::GPTJ_3);
 bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
 // bool blasmode = false;
 int original_batch = params.n_batch;
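
The comment in this hunk states the strategy: when an approved format, a prompt of at least 32 tokens, and a BLAS build all line up, prompt processing runs faster single-threaded with one huge batch, because the work collapses into large GEMM calls that BLAS parallelizes internally. A rough sketch of the switch that original_batch is being saved for; the concrete batch size and the n_threads handling here are assumptions, not the commit's values:

    int original_threads = params.n_threads; // assumed counterpart to original_batch
    if (blasmode)
    {
        params.n_batch = 512;  // assumed large batch; the real value may differ
        params.n_threads = 1;  // let BLAS handle parallelism internally
    }
    // ... evaluate the prompt, then restore params.n_batch = original_batch, etc.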
@@ -304,6 +303,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 else if(file_format == FileFormat::RWKV_1)
 {
 n_vocab = vocab.id_to_token.size(); //handled seperately
+rwkv_context_v1->state_in = nullptr;
 }
 else
 {
@@ -333,9 +333,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 if(file_format==FileFormat::RWKV_1)
 {
-printf("\nsiz:%d val:%d\n",embd.size(),embd[0]);
 evalres = rwkv_eval(rwkv_context_v1, embd[0], rwkv_context_v1->state_in, rwkv_context_v1->state_out, rwkv_context_v1->logits_out);
 memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
+rwkv_context_v1->state_in = rwkv_context_v1->state_out;
 }
 else if(file_format==FileFormat::GPT2_1)
 {
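
Read together with the state_in = nullptr hunk above, this implements RWKV's recurrent evaluation: a null input state marks a fresh sequence, and after every token the output state is fed back in as the next input state. A condensed sketch of the loop these two hunks produce, using only the calls that appear in this diff (the loop shape itself is paraphrased; the real code feeds one token per outer-loop iteration as embd[0]):

    rwkv_context_v1->state_in = nullptr; // fresh sequence: no prior state
    for (auto token : embd_inp)
    {
        evalres = rwkv_eval(rwkv_context_v1, token,
                            rwkv_context_v1->state_in,
                            rwkv_context_v1->state_out,
                            rwkv_context_v1->logits_out);
        memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
        rwkv_context_v1->state_in = rwkv_context_v1->state_out; // carry state forward
    }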

File diff suppressed because one or more lines are too long


@@ -6,6 +6,17 @@
 #include "expose.h"
 std::vector<std::string> rwkv_vocab;
+std::vector<std::string> special = {"Ā","ā","Ă","ă","Ą","ą","Ć","ć","Ĉ","ĉ","Ċ","ċ","Č","č","Ď","ď","Đ","đ","Ē","ē","Ĕ","ĕ","Ė","ė","Ę","ę","Ě","ě","Ĝ","ĝ","Ğ","ğ","Ġ","!","\"","#","$","%","&","\'","(",")","*","+",",","-",".","/","0","1","2","3","4","5","6","7","8","9",":",";","<","=",">","?","@","A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z","[","\\","]","^","_","`","a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z","{","|","}","~","ġ","Ģ","ģ","Ĥ","ĥ","Ħ","ħ","Ĩ","ĩ","Ī","ī","Ĭ","ĭ","Į","į","İ","ı","IJ","ij","Ĵ","ĵ","Ķ","ķ","ĸ","Ĺ","ĺ","Ļ","ļ","Ľ","ľ","Ŀ","ŀ","Ł","ł","¡","¢","£","¤","¥","¦","§","¨","©","ª","«","¬","Ń","®","¯","°","±","²","³","´","µ","¶","·","¸","¹","º","»","¼","½","¾","¿","À","Á","Â","Ã","Ä","Å","Æ","Ç","È","É","Ê","Ë","Ì","Í","Î","Ï","Ð","Ñ","Ò","Ó","Ô","Õ","Ö","×","Ø","Ù","Ú","Û","Ü","Ý","Þ","ß","à","á","â","ã","ä","å","æ","ç","è","é","ê","ë","ì","í","î","ï","ð","ñ","ò","ó","ô","õ","ö","÷","ø","ù","ú","û","ü","ý","þ","ÿ"};
+static void replaceAll(std::string& str, const std::string& from, const std::string& to) {
+    if(from.empty())
+        return;
+    size_t start_pos = 0;
+    while((start_pos = str.find(from, start_pos)) != std::string::npos) {
+        str.replace(start_pos, from.length(), to);
+        start_pos += to.length(); // In case 'to' contains 'from', like replacing 'x' with 'yx'
+    }
+}
 void read_rwkv_vocab()
 {
@@ -15,10 +26,17 @@ void read_rwkv_vocab()
 std::ifstream myfile(filepath);
 if (myfile.is_open())
 {
+    int slen = special.size();
     while (myfile.good())
     {
         getline(myfile, line);
-        rwkv_vocab.push_back(line);
+        for(int i=0;i<slen;++i)
+        {
+            std::string swapped = "";
+            swapped.push_back((char)i);
+            replaceAll(line,special[i],swapped);
+        }
+        rwkv_vocab.push_back(line);
     }
     myfile.close();
 }
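
The special table is the GPT-2 byte-to-unicode alphabet listed in byte order, so special[i] is the printable stand-in for raw byte i, and the new inner loop substitutes each stand-in back to its byte as the vocab is loaded. For instance special[32] is "Ġ" and (char)32 is a space, which is how these vocabularies mark leading spaces. A self-contained illustration reusing the replaceAll helper added in this commit:

    #include <cstdio>
    #include <string>

    // replaceAll as added in this commit
    static void replaceAll(std::string& str, const std::string& from, const std::string& to) {
        if(from.empty()) return;
        size_t start_pos = 0;
        while((start_pos = str.find(from, start_pos)) != std::string::npos) {
            str.replace(start_pos, from.length(), to);
            start_pos += to.length();
        }
    }

    int main() {
        std::string line = "Ġhello";     // a vocab entry as stored in rwkv_vocab.embd
        replaceAll(line, "Ġ", " ");      // special[32] -> (char)32, i.e. a space
        printf("[%s]\n", line.c_str());  // prints [ hello]
        return 0;
    }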


@@ -1,13 +1,45 @@
-import json
-with open("rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
+import json,os
+
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])
+tr = "{"
+for i in sortedbd:
+    tr += "\""+i[0]+"\","
+tr += "}"
+print(tr)
+
+with open((os.path.dirname(os.path.realpath(__file__))+"/") + "rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
     encoder = json.load(f)
 s = ""
 with open("rwkv_vocab.embd", "w", encoding="utf-8") as f2:
     for key in encoder:
+        #key = bytearray([byte_decoder[c] for c in key]).decode('utf-8','ignore')
         # key = key.replace("\\","\\\\")
         # key = key.replace("\"","\\\"")
         # s += "\""+key+"\",\n"
         s += key +"\n"
     f2.write(s)
 print("OK")