diff --git a/expose.cpp b/expose.cpp
index d07d978c8..c2ae30767 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -128,7 +128,7 @@ extern "C"
             return true;
         }
     }
-    else if(file_format==FileFormat::RWKV_1)
+    else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         printf("\n---\nIdentified as RWKV model: (ver %d)\nAttempting to Load...\n---\n", file_format);
         ModelLoadResult lr = gpttype_load_model(inputs, file_format);
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index d65836ceb..cedba1286 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -45,6 +45,7 @@
 static gpt_neox_v2_model neox_ctx_v2;
 static gpt_neox_model neox_ctx_v3;
 static rwkv_v2_context * rwkv_ctx_v2;
+static rwkv_context * rwkv_ctx_v3;
 static llama_v2_context_params llama_ctx_params_v2;
 static llama_context_params llama_ctx_params;
 static llama_v2_context * llama_ctx_v2;
@@ -389,45 +390,78 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
         return ModelLoadResult::SUCCESS;
     }
-    else if (file_format == FileFormat::RWKV_1)
+    else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
-        rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
-
-        //setup buffers for rwkv state
-        auto padding = 512u;
-        auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
-        auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
-
-        printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
-        rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
-        rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
-        rwkv_ctx_v2->state_in = nullptr;
 
         n_batch = 1;
-
         std::string word;
         read_rwkv_vocab();
         int vocabsiz = rwkv_vocab.size();
-        for (int i = 0; i < vocabsiz; i++) {
+        for (int i = 0; i < vocabsiz; i++)
+        {
             uint32_t len;
             word = rwkv_vocab[i];
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
         }
-        printf("\nRWKV Vocab: %u\n",vocabsiz);
-
-        bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
-        if(!testeval)
-        {
-            printf("\nError: RWKV Init Eval Failed!\n");
-        }
+        printf("\nRWKV Vocab: %u\n", vocabsiz);
         logits.resize(vocabsiz);
-        memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
-        if (rwkv_ctx_v2 == NULL)
+        if (file_format == FileFormat::RWKV_1)
         {
-            return ModelLoadResult::FAIL;
+            rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
+
+            //setup buffers for rwkv state
+            auto padding = 512u;
+            auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
+            auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
+
+            printf("\nRWKV old Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
+            rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
+            rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
+            rwkv_ctx_v2->state_in = nullptr;
+
+            bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
+            if (!testeval)
+            {
+                printf("\nError: RWKV old Init Eval Failed!\n");
+            }
+
+            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * vocabsiz);
+
+            if (rwkv_ctx_v2 == NULL)
+            {
+                return ModelLoadResult::FAIL;
+            }
+            return ModelLoadResult::SUCCESS;
+        }
+        else
+        {
+            rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
+
+            //setup buffers for rwkv state
+            auto padding = 512u;
+            auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
+            auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v3) * sizeof(float) + padding;
+
+            printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
+            rwkv_ctx_v3->state_out = (float *)malloc(statebufsiz);
+            rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
+            rwkv_ctx_v3->state_in = nullptr;
+
+            bool testeval = rwkv_eval(rwkv_ctx_v3, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+            if (!testeval)
+            {
+                printf("\nError: RWKV Init Eval Failed!\n");
+            }
+
+            memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * vocabsiz);
+
+            if (rwkv_ctx_v3 == NULL)
+            {
+                return ModelLoadResult::FAIL;
+            }
+            return ModelLoadResult::SUCCESS;
         }
-        return ModelLoadResult::SUCCESS;
     }
     else if (file_format == FileFormat::GPT2_1)
     {
@@ -741,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    if (file_format == FileFormat::RWKV_1)
+    if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
     }
@@ -755,7 +789,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         file_format == FileFormat::GPT2_1 ||
         file_format == FileFormat::GPTJ_1 ||
         file_format == FileFormat::GPTJ_2 ||
-        file_format == FileFormat::RWKV_1);
+        file_format == FileFormat::RWKV_1 ||
+        file_format==FileFormat::RWKV_2);
     bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
     // bool blasmode = false;
     int original_batch = params.n_batch;
@@ -834,16 +869,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         n_vocab = neox_ctx_v3.hparams.n_vocab;
    }
-    else if(file_format == FileFormat::RWKV_1)
+    else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
    {
         n_vocab = vocab.id_to_token.size(); //handled seperately
         if(n_past==0)
         {
-            rwkv_ctx_v2->state_in = nullptr;
+            if(file_format == FileFormat::RWKV_1)
+            {
+                rwkv_ctx_v2->state_in = nullptr;
+            }
+            else
+            {
+                rwkv_ctx_v3->state_in = nullptr;
+            }
         }
         else
         {
-            rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
+            if (file_format == FileFormat::RWKV_1)
+            {
+                rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
+            }
+            else
+            {
+                rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
+            }
+
             //if it's empty, push in the final previous token
             if(embd_inp.size()==0 && current_context_tokens.size()>0)
             {
@@ -909,11 +959,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         {
             evalres = (llama_eval(llama_ctx_v3, embd.data(), embdsize, n_past, params.n_threads)==0);
         }
-        else if(file_format==FileFormat::RWKV_1)
+        else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
         {
-            evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
-            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
-            rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
+            if (file_format == FileFormat::RWKV_1)
+            {
+                evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
+                memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * rwkv_vocab.size());
+                rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
+            }
+            else
+            {
+                evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+                memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
+                rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
+            }
         }
         else if(file_format==FileFormat::GPT2_1)
         {
diff --git a/model_adapter.cpp b/model_adapter.cpp
index 6864e0d37..042745912 100644
--- a/model_adapter.cpp
+++ b/model_adapter.cpp
@@ -212,6 +212,10 @@ void print_tok_vec(std::vector &embd)
         {
             fileformat = FileFormat::RWKV_1;
         }
+        else if(temp==101)
+        {
+            fileformat = FileFormat::RWKV_2;
+        }
     }
     else if(magic == 0x67676a74) //v3 format ggjt
     {
diff --git a/model_adapter.h b/model_adapter.h
index 4b418d4d8..ab6769910 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -34,6 +34,7 @@ enum FileFormat
     GPT2_4=203, //using 16bit scalar
 
     RWKV_1=300,
+    RWKV_2=301,
 
     NEOX_1=400,
     NEOX_2=401,