rwkv integration completed
This commit is contained in:
parent
55e0fbf024
commit
5d9f5b28a6
4 changed files with 99 additions and 35 deletions
|
@ -128,7 +128,7 @@ extern "C"
|
|||
return true;
|
||||
}
|
||||
}
|
||||
else if(file_format==FileFormat::RWKV_1)
|
||||
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
printf("\n---\nIdentified as RWKV model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
ModelLoadResult lr = gpttype_load_model(inputs, file_format);
|
||||
|
|
|
@ -45,6 +45,7 @@ static gpt_neox_v2_model neox_ctx_v2;
|
|||
static gpt_neox_model neox_ctx_v3;
|
||||
|
||||
static rwkv_v2_context * rwkv_ctx_v2;
|
||||
static rwkv_context * rwkv_ctx_v3;
|
||||
static llama_v2_context_params llama_ctx_params_v2;
|
||||
static llama_context_params llama_ctx_params;
|
||||
static llama_v2_context * llama_ctx_v2;
|
||||
|
@ -389,45 +390,78 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else if (file_format == FileFormat::RWKV_1)
|
||||
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
auto padding = 512u;
|
||||
auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
|
||||
auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
|
||||
|
||||
printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
|
||||
rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
|
||||
rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
|
||||
rwkv_ctx_v2->state_in = nullptr;
|
||||
n_batch = 1;
|
||||
|
||||
std::string word;
|
||||
read_rwkv_vocab();
|
||||
int vocabsiz = rwkv_vocab.size();
|
||||
for (int i = 0; i < vocabsiz; i++) {
|
||||
for (int i = 0; i < vocabsiz; i++)
|
||||
{
|
||||
uint32_t len;
|
||||
word = rwkv_vocab[i];
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
printf("\nRWKV Vocab: %u\n",vocabsiz);
|
||||
|
||||
bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
|
||||
if(!testeval)
|
||||
{
|
||||
printf("\nError: RWKV Init Eval Failed!\n");
|
||||
}
|
||||
printf("\nRWKV Vocab: %u\n", vocabsiz);
|
||||
logits.resize(vocabsiz);
|
||||
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
|
||||
|
||||
if (rwkv_ctx_v2 == NULL)
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
return ModelLoadResult::FAIL;
|
||||
rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
auto padding = 512u;
|
||||
auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
|
||||
auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
|
||||
|
||||
printf("\nRWKV old Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
|
||||
rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
|
||||
rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
|
||||
rwkv_ctx_v2->state_in = nullptr;
|
||||
|
||||
bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
|
||||
if (!testeval)
|
||||
{
|
||||
printf("\nError: RWKV old Init Eval Failed!\n");
|
||||
}
|
||||
|
||||
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * vocabsiz);
|
||||
|
||||
if (rwkv_ctx_v2 == NULL)
|
||||
{
|
||||
return ModelLoadResult::FAIL;
|
||||
}
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else
|
||||
{
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
//setup buffers for rwkv state
|
||||
auto padding = 512u;
|
||||
auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
|
||||
auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
|
||||
|
||||
printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
|
||||
rwkv_ctx_v3->state_out = (float *)malloc(statebufsiz);
|
||||
rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
|
||||
rwkv_ctx_v3->state_in = nullptr;
|
||||
|
||||
bool testeval = rwkv_eval(rwkv_ctx_v3, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
if (!testeval)
|
||||
{
|
||||
printf("\nError: RWKV Init Eval Failed!\n");
|
||||
}
|
||||
|
||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * vocabsiz);
|
||||
|
||||
if (rwkv_ctx_v3 == NULL)
|
||||
{
|
||||
return ModelLoadResult::FAIL;
|
||||
}
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else if (file_format == FileFormat::GPT2_1)
|
||||
{
|
||||
|
@ -741,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
n_past = 0;
|
||||
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
|
||||
}
|
||||
|
@ -755,7 +789,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
file_format == FileFormat::GPT2_1 ||
|
||||
file_format == FileFormat::GPTJ_1 ||
|
||||
file_format == FileFormat::GPTJ_2 ||
|
||||
file_format == FileFormat::RWKV_1);
|
||||
file_format == FileFormat::RWKV_1 ||
|
||||
file_format==FileFormat::RWKV_2);
|
||||
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
|
||||
// bool blasmode = false;
|
||||
int original_batch = params.n_batch;
|
||||
|
@ -834,16 +869,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
{
|
||||
n_vocab = neox_ctx_v3.hparams.n_vocab;
|
||||
}
|
||||
else if(file_format == FileFormat::RWKV_1)
|
||||
else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
n_vocab = vocab.id_to_token.size(); //handled seperately
|
||||
if(n_past==0)
|
||||
{
|
||||
rwkv_ctx_v2->state_in = nullptr;
|
||||
if(file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
rwkv_ctx_v2->state_in = nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
rwkv_ctx_v3->state_in = nullptr;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||
}
|
||||
else
|
||||
{
|
||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||
}
|
||||
|
||||
//if it's empty, push in the final previous token
|
||||
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
||||
{
|
||||
|
@ -909,11 +959,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
{
|
||||
evalres = (llama_eval(llama_ctx_v3, embd.data(), embdsize, n_past, params.n_threads)==0);
|
||||
}
|
||||
else if(file_format==FileFormat::RWKV_1)
|
||||
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
|
||||
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
|
||||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
|
||||
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
|
||||
}
|
||||
else
|
||||
{
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||
}
|
||||
}
|
||||
else if(file_format==FileFormat::GPT2_1)
|
||||
{
|
||||
|
|
|
@ -212,6 +212,10 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
{
|
||||
fileformat = FileFormat::RWKV_1;
|
||||
}
|
||||
else if(temp==101)
|
||||
{
|
||||
fileformat = FileFormat::RWKV_2;
|
||||
}
|
||||
}
|
||||
else if(magic == 0x67676a74) //v3 format ggjt
|
||||
{
|
||||
|
|
|
@ -34,6 +34,7 @@ enum FileFormat
|
|||
GPT2_4=203, //using 16bit scalar
|
||||
|
||||
RWKV_1=300,
|
||||
RWKV_2=301,
|
||||
|
||||
NEOX_1=400,
|
||||
NEOX_2=401,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue