RWKV integration completed
parent 55e0fbf024
commit 5d9f5b28a6
4 changed files with 99 additions and 35 deletions
@@ -128,7 +128,7 @@ extern "C"
                 return true;
             }
         }
-        else if(file_format==FileFormat::RWKV_1)
+        else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
         {
             printf("\n---\nIdentified as RWKV model: (ver %d)\nAttempting to Load...\n---\n", file_format);
             ModelLoadResult lr = gpttype_load_model(inputs, file_format);
@@ -45,6 +45,7 @@ static gpt_neox_v2_model neox_ctx_v2;
 static gpt_neox_model neox_ctx_v3;
 
 static rwkv_v2_context * rwkv_ctx_v2;
+static rwkv_context * rwkv_ctx_v3;
 static llama_v2_context_params llama_ctx_params_v2;
 static llama_context_params llama_ctx_params;
 static llama_v2_context * llama_ctx_v2;
@@ -389,7 +390,23 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
         return ModelLoadResult::SUCCESS;
     }
-    else if (file_format == FileFormat::RWKV_1)
+    else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
+    {
+        n_batch = 1;
+        std::string word;
+        read_rwkv_vocab();
+        int vocabsiz = rwkv_vocab.size();
+        for (int i = 0; i < vocabsiz; i++)
+        {
+            uint32_t len;
+            word = rwkv_vocab[i];
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+        printf("\nRWKV Vocab: %u\n", vocabsiz);
+        logits.resize(vocabsiz);
+
+        if (file_format == FileFormat::RWKV_1)
         {
             rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
 
@@ -398,30 +415,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
             auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
 
-            printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
+            printf("\nRWKV old Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
             rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
             rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
             rwkv_ctx_v2->state_in = nullptr;
-            n_batch = 1;
 
-            std::string word;
-            read_rwkv_vocab();
-            int vocabsiz = rwkv_vocab.size();
-            for (int i = 0; i < vocabsiz; i++) {
-                uint32_t len;
-                word = rwkv_vocab[i];
-                vocab.token_to_id[word] = i;
-                vocab.id_to_token[i] = word;
-            }
-            printf("\nRWKV Vocab: %u\n",vocabsiz);
 
             bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
-            if(!testeval)
+            if (!testeval)
             {
-                printf("\nError: RWKV Init Eval Failed!\n");
+                printf("\nError: RWKV old Init Eval Failed!\n");
             }
-            logits.resize(vocabsiz);
-            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
+            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * vocabsiz);
 
             if (rwkv_ctx_v2 == NULL)
             {
@@ -429,6 +434,35 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             }
             return ModelLoadResult::SUCCESS;
         }
+        else
+        {
+            rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
+
+            //setup buffers for rwkv state
+            auto padding = 512u;
+            auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
+            auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
+
+            printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
+            rwkv_ctx_v3->state_out = (float *)malloc(statebufsiz);
+            rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
+            rwkv_ctx_v3->state_in = nullptr;
+
+            bool testeval = rwkv_eval(rwkv_ctx_v3, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+            if (!testeval)
+            {
+                printf("\nError: RWKV Init Eval Failed!\n");
+            }
+
+            memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * vocabsiz);
+
+            if (rwkv_ctx_v3 == NULL)
+            {
+                return ModelLoadResult::FAIL;
+            }
+            return ModelLoadResult::SUCCESS;
+        }
+    }
     else if (file_format == FileFormat::GPT2_1)
     {
         ModelLoadResult res = legacy_gpt2_model_load(params.model, gpt2_ctx_v1, vocab, file_format);
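For reference, the new load path above follows the standard rwkv.cpp call sequence: open the model, ask the context how large the caller-owned state and logits buffers must be, allocate them, and run a first evaluation. A minimal standalone sketch of that sequence (not part of this commit; the header path, model path, and thread count are placeholder assumptions):

    #include <cstdlib>
    #include "rwkv.h"   // rwkv.cpp header, assumed to be on the include path

    // Load the model; the context reports the sizes of the caller-owned buffers.
    rwkv_context * ctx = rwkv_init_from_file("model.bin", /*n_threads=*/4); // placeholder path and thread count
    size_t padding = 512;
    float * state_out  = (float *) malloc(rwkv_get_state_buffer_element_count(ctx)  * sizeof(float) + padding);
    float * logits_out = (float *) malloc(rwkv_get_logits_buffer_element_count(ctx) * sizeof(float) + padding);

    // A first evaluation with a null input state primes the model and fills logits_out,
    // mirroring the test eval performed in the load path above.
    bool ok = rwkv_eval(ctx, 0, /*state_in=*/nullptr, state_out, logits_out);

The buffers are allocated once at load time and reused for every subsequent token, which is why the commit keeps them on the context.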
@@ -741,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    if (file_format == FileFormat::RWKV_1)
+    if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
     }
@@ -755,7 +789,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                             file_format == FileFormat::GPT2_1 ||
                             file_format == FileFormat::GPTJ_1 ||
                             file_format == FileFormat::GPTJ_2 ||
-                            file_format == FileFormat::RWKV_1);
+                            file_format == FileFormat::RWKV_1 ||
+                            file_format==FileFormat::RWKV_2);
     bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
     // bool blasmode = false;
     int original_batch = params.n_batch;
@@ -834,16 +869,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         n_vocab = neox_ctx_v3.hparams.n_vocab;
     }
-    else if(file_format == FileFormat::RWKV_1)
+    else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         n_vocab = vocab.id_to_token.size(); //handled seperately
         if(n_past==0)
+        {
+            if(file_format == FileFormat::RWKV_1)
             {
                 rwkv_ctx_v2->state_in = nullptr;
             }
             else
+            {
+                rwkv_ctx_v3->state_in = nullptr;
+            }
+        }
+        else
+        {
+            if (file_format == FileFormat::RWKV_1)
             {
                 rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
+            }
+            else
+            {
+                rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
+            }
+
             //if it's empty, push in the final previous token
             if(embd_inp.size()==0 && current_context_tokens.size()>0)
             {
@@ -909,12 +959,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         {
             evalres = (llama_eval(llama_ctx_v3, embd.data(), embdsize, n_past, params.n_threads)==0);
         }
-        else if(file_format==FileFormat::RWKV_1)
+        else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
+        {
+            if (file_format == FileFormat::RWKV_1)
             {
                 evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
-                memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
+                memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * rwkv_vocab.size());
                 rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
             }
+            else
+            {
+                evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+                memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
+                rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
+            }
+        }
         else if(file_format==FileFormat::GPT2_1)
         {
             evalres = legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
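Unlike the transformer back-ends above, which grow a KV cache indexed by n_past, the RWKV paths are recurrent: each eval consumes the previous state and produces a new one, which is why both branches copy state_out back into state_in after every call. A minimal sketch of that hand-off pattern (hypothetical loop, assuming ctx, state_out, and logits_out were set up as in the load path; prompt_tokens is a placeholder for already-tokenized input):

    // Feed a sequence one token at a time, threading the recurrent state through each call.
    float * state_in = nullptr;                // null state == fresh context, as on the first token above
    for (int token : prompt_tokens)            // placeholder std::vector<int>
    {
        rwkv_eval(ctx, token, state_in, state_out, logits_out);
        state_in = state_out;                  // carry the state forward, exactly as the generate loop does
    }
    // logits_out now holds the next-token distribution for sampling.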
@@ -212,6 +212,10 @@ void print_tok_vec(std::vector<float> &embd)
         {
             fileformat = FileFormat::RWKV_1;
         }
+        else if(temp==101)
+        {
+            fileformat = FileFormat::RWKV_2;
+        }
     }
     else if(magic == 0x67676a74) //v3 format ggjt
     {
@@ -34,6 +34,7 @@ enum FileFormat
     GPT2_4=203, //using 16bit scalar
 
     RWKV_1=300,
+    RWKV_2=301,
 
     NEOX_1=400,
     NEOX_2=401,