rwkv is done
This commit is contained in:
parent
a76b15b581
commit
ea01771dd5
5 changed files with 62 additions and 12 deletions
4
Makefile
4
Makefile
|
@ -130,8 +130,8 @@ ifdef LLAMA_GPROF
|
|||
CXXFLAGS += -pg
|
||||
endif
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
CFLAGS += -mcpu=native
|
||||
CXXFLAGS += -mcpu=native
|
||||
CFLAGS +=
|
||||
CXXFLAGS +=
|
||||
endif
|
||||
ifneq ($(filter armv6%,$(UNAME_M)),)
|
||||
# Raspberry Pi 1, 2, 3
|
||||
|
|
|
@ -232,7 +232,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
|
||||
print_tok_vec(embd_inp);
|
||||
|
||||
//truncate to front of the prompt if its too long
|
||||
int32_t nctx = params.n_ctx;
|
||||
|
@ -258,7 +257,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
|
||||
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
|
||||
bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
|
||||
bool approved_format = (file_format==FileFormat::GPT2_2 || file_format==FileFormat::GPTJ_3);
|
||||
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
|
||||
// bool blasmode = false;
|
||||
int original_batch = params.n_batch;
|
||||
|
@ -304,6 +303,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
else if(file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
n_vocab = vocab.id_to_token.size(); //handled seperately
|
||||
rwkv_context_v1->state_in = nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -333,9 +333,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
|
||||
if(file_format==FileFormat::RWKV_1)
|
||||
{
|
||||
printf("\nsiz:%d val:%d\n",embd.size(),embd[0]);
|
||||
evalres = rwkv_eval(rwkv_context_v1, embd[0], rwkv_context_v1->state_in, rwkv_context_v1->state_out, rwkv_context_v1->logits_out);
|
||||
memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
|
||||
rwkv_context_v1->state_in = rwkv_context_v1->state_out;
|
||||
}
|
||||
else if(file_format==FileFormat::GPT2_1)
|
||||
{
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -6,6 +6,17 @@
|
|||
#include "expose.h"
|
||||
|
||||
// RWKV vocabulary, filled by read_rwkv_vocab(): one token string per token id.
std::vector<std::string> rwkv_vocab;
// GPT-2 style byte<->printable-unicode table: entry i is the printable
// character that stands in for raw byte i in the vocab file.  read_rwkv_vocab()
// replaces each special[i] occurrence in a vocab line with the raw byte (char)i.
// Generated by the companion Python script's bytes_to_unicode() dump.
std::vector<std::string> special = {"Ā","ā","Ă","ă","Ą","ą","Ć","ć","Ĉ","ĉ","Ċ","ċ","Č","č","Ď","ď","Đ","đ","Ē","ē","Ĕ","ĕ","Ė","ė","Ę","ę","Ě","ě","Ĝ","ĝ","Ğ","ğ","Ġ","!","\"","#","$","%","&","\'","(",")","*","+",",","-",".","/","0","1","2","3","4","5","6","7","8","9",":",";","<","=",">","?","@","A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z","[","\\","]","^","_","`","a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z","{","|","}","~","ġ","Ģ","ģ","Ĥ","ĥ","Ħ","ħ","Ĩ","ĩ","Ī","ī","Ĭ","ĭ","Į","į","İ","ı","IJ","ij","Ĵ","ĵ","Ķ","ķ","ĸ","Ĺ","ĺ","Ļ","ļ","Ľ","ľ","Ŀ","ŀ","Ł","ł","¡","¢","£","¤","¥","¦","§","¨","©","ª","«","¬","Ń","®","¯","°","±","²","³","´","µ","¶","·","¸","¹","º","»","¼","½","¾","¿","À","Á","Â","Ã","Ä","Å","Æ","Ç","È","É","Ê","Ë","Ì","Í","Î","Ï","Ð","Ñ","Ò","Ó","Ô","Õ","Ö","×","Ø","Ù","Ú","Û","Ü","Ý","Þ","ß","à","á","â","ã","ä","å","æ","ç","è","é","ê","ë","ì","í","î","ï","ð","ñ","ò","ó","ô","õ","ö","÷","ø","ù","ú","û","ü","ý","þ","ÿ"};
|
||||
|
||||
// Replace every occurrence of `from` in `str` with `to`, in place.
// The search resumes past each inserted replacement, so a `to` that
// contains `from` (e.g. replacing "x" with "yx") cannot loop forever.
static void replaceAll(std::string& str, const std::string& from, const std::string& to) {
    if (from.empty()) {
        return; // an empty needle would match at every position and never terminate
    }
    size_t pos = str.find(from);
    while (pos != std::string::npos) {
        str.replace(pos, from.length(), to);
        pos = str.find(from, pos + to.length());
    }
}
|
||||
|
||||
void read_rwkv_vocab()
|
||||
{
|
||||
|
@ -15,9 +26,16 @@ void read_rwkv_vocab()
|
|||
std::ifstream myfile(filepath);
|
||||
if (myfile.is_open())
|
||||
{
|
||||
int slen = special.size();
|
||||
while (myfile.good())
|
||||
{
|
||||
getline(myfile, line);
|
||||
for(int i=0;i<slen;++i)
|
||||
{
|
||||
std::string swapped = "";
|
||||
swapped.push_back((char)i);
|
||||
replaceAll(line,special[i],swapped);
|
||||
}
|
||||
rwkv_vocab.push_back(line);
|
||||
}
|
||||
myfile.close();
|
||||
|
|
|
@ -1,9 +1,41 @@
|
|||
import json
|
||||
with open("rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
|
||||
import json,os
|
||||
|
||||
def bytes_to_unicode():
    """Build the GPT-2 byte -> unicode-character lookup table.

    Returns a dict mapping every byte value (0-255) to a printable unicode
    character, so arbitrary byte sequences can be represented as unicode
    strings with no whitespace/control characters (which the BPE code barfs
    on).  Printable latin-1 bytes map to themselves; the remaining bytes are
    shifted into fresh code points starting at chr(256).
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = list(printable)
    char_codes = list(printable)
    shift = 0
    for byte in range(256):
        if byte not in printable:
            byte_values.append(byte)
            char_codes.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, char_codes)}
|
||||
|
||||
# Dump the byte-decoder table as a C++ string-array initializer, e.g.
# {"Ā","ā",...,"!","\"",...}; entry i is the character standing in for byte i.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
# Sort by byte value so that the array index equals the byte it decodes to.
sortedbd = sorted(byte_decoder.items(), key=lambda kv: kv[1])
tr = "{"
for ch, _byte in sortedbd:
    # Escape backslash and double quote — both appear in the table (bytes
    # 92 and 34) and would otherwise break the emitted C++ string literals.
    escaped = ch.replace("\\", "\\\\").replace("\"", "\\\"")
    tr += "\"" + escaped + "\","
tr += "}"
print(tr)
|
||||
|
||||
with open((os.path.dirname(os.path.realpath(__file__))+"/") + "rwkv_orig_vocab.json", "r", encoding="utf-8") as f:
|
||||
encoder = json.load(f)
|
||||
s = ""
|
||||
with open("rwkv_vocab.embd", "w", encoding="utf-8") as f2:
|
||||
for key in encoder:
|
||||
#key = bytearray([byte_decoder[c] for c in key]).decode('utf-8','ignore')
|
||||
# key = key.replace("\\","\\\\")
|
||||
# key = key.replace("\"","\\\"")
|
||||
# s += "\""+key+"\",\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue