rwkv integration completed

This commit is contained in:
Concedo 2023-05-28 00:48:56 +08:00
parent 55e0fbf024
commit 5d9f5b28a6
4 changed files with 99 additions and 35 deletions

View file

@ -128,7 +128,7 @@ extern "C"
return true;
}
}
else if(file_format==FileFormat::RWKV_1)
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
printf("\n---\nIdentified as RWKV model: (ver %d)\nAttempting to Load...\n---\n", file_format);
ModelLoadResult lr = gpttype_load_model(inputs, file_format);

View file

@ -45,6 +45,7 @@ static gpt_neox_v2_model neox_ctx_v2;
static gpt_neox_model neox_ctx_v3;
static rwkv_v2_context * rwkv_ctx_v2;
static rwkv_context * rwkv_ctx_v3;
static llama_v2_context_params llama_ctx_params_v2;
static llama_context_params llama_ctx_params;
static llama_v2_context * llama_ctx_v2;
@ -389,45 +390,78 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
return ModelLoadResult::SUCCESS;
}
else if (file_format == FileFormat::RWKV_1)
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
//setup buffers for rwkv state
auto padding = 512u;
auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
rwkv_ctx_v2->state_in = nullptr;
n_batch = 1;
std::string word;
read_rwkv_vocab();
int vocabsiz = rwkv_vocab.size();
for (int i = 0; i < vocabsiz; i++) {
for (int i = 0; i < vocabsiz; i++)
{
uint32_t len;
word = rwkv_vocab[i];
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
printf("\nRWKV Vocab: %u\n",vocabsiz);
bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
if(!testeval)
{
printf("\nError: RWKV Init Eval Failed!\n");
}
printf("\nRWKV Vocab: %u\n", vocabsiz);
logits.resize(vocabsiz);
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
if (rwkv_ctx_v2 == NULL)
if (file_format == FileFormat::RWKV_1)
{
return ModelLoadResult::FAIL;
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
//setup buffers for rwkv state
auto padding = 512u;
auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
printf("\nRWKV old Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
rwkv_ctx_v2->state_in = nullptr;
bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
if (!testeval)
{
printf("\nError: RWKV old Init Eval Failed!\n");
}
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * vocabsiz);
if (rwkv_ctx_v2 == NULL)
{
return ModelLoadResult::FAIL;
}
return ModelLoadResult::SUCCESS;
}
else
{
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
//setup buffers for rwkv state
auto padding = 512u;
auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
rwkv_ctx_v3->state_out = (float *)malloc(statebufsiz);
rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
rwkv_ctx_v3->state_in = nullptr;
bool testeval = rwkv_eval(rwkv_ctx_v3, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
if (!testeval)
{
printf("\nError: RWKV Init Eval Failed!\n");
}
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * vocabsiz);
if (rwkv_ctx_v3 == NULL)
{
return ModelLoadResult::FAIL;
}
return ModelLoadResult::SUCCESS;
}
return ModelLoadResult::SUCCESS;
}
else if (file_format == FileFormat::GPT2_1)
{
@ -741,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
n_past = 0;
if (file_format == FileFormat::RWKV_1)
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
}
@ -755,7 +789,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::RWKV_1);
file_format == FileFormat::RWKV_1 ||
file_format==FileFormat::RWKV_2);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
// bool blasmode = false;
int original_batch = params.n_batch;
@ -834,16 +869,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
n_vocab = neox_ctx_v3.hparams.n_vocab;
}
else if(file_format == FileFormat::RWKV_1)
else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
n_vocab = vocab.id_to_token.size(); //handled seperately
if(n_past==0)
{
rwkv_ctx_v2->state_in = nullptr;
if(file_format == FileFormat::RWKV_1)
{
rwkv_ctx_v2->state_in = nullptr;
}
else
{
rwkv_ctx_v3->state_in = nullptr;
}
}
else
{
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
if (file_format == FileFormat::RWKV_1)
{
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
}
else
{
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
}
//if it's empty, push in the final previous token
if(embd_inp.size()==0 && current_context_tokens.size()>0)
{
@ -909,11 +959,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
evalres = (llama_eval(llama_ctx_v3, embd.data(), embdsize, n_past, params.n_threads)==0);
}
else if(file_format==FileFormat::RWKV_1)
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
if (file_format == FileFormat::RWKV_1)
{
evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * rwkv_vocab.size());
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
}
else
{
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
}
}
else if(file_format==FileFormat::GPT2_1)
{

View file

@ -212,6 +212,10 @@ void print_tok_vec(std::vector<float> &embd)
{
fileformat = FileFormat::RWKV_1;
}
else if(temp==101)
{
fileformat = FileFormat::RWKV_2;
}
}
else if(magic == 0x67676a74) //v3 format ggjt
{

View file

@ -34,6 +34,7 @@ enum FileFormat
GPT2_4=203, //using 16bit scalar
RWKV_1=300,
RWKV_2=301,
NEOX_1=400,
NEOX_2=401,