From 55e0fbf0247c285f494dbee75a713650fd186d71 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sat, 27 May 2023 22:45:28 +0800 Subject: [PATCH] wip integrating new rwkv --- gpttype_adapter.cpp | 31 +- otherarch/rwkv_v2.cpp | 250 ++++---- otherarch/rwkv_v2.h | 40 +- otherarch/rwkv_v3.cpp | 948 +++++++++++++++++++++++++++++ otherarch/rwkv_v3.h | 125 ++++ otherarch/tools/convert_pt_rwkv.py | 92 +-- 6 files changed, 1239 insertions(+), 247 deletions(-) create mode 100644 otherarch/rwkv_v3.cpp create mode 100644 otherarch/rwkv_v3.h diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index d82175100..d65836ceb 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -23,6 +23,7 @@ #include "gpt2_v2.cpp" #include "gpt2_v3.cpp" #include "rwkv_v2.cpp" +#include "rwkv_v3.cpp" #include "neox_v2.cpp" #include "neox_v3.cpp" @@ -43,7 +44,7 @@ static gpt2_model gpt2_ctx_v3; static gpt_neox_v2_model neox_ctx_v2; static gpt_neox_model neox_ctx_v3; -static rwkv_context * rwkv_ctx_v1; +static rwkv_v2_context * rwkv_ctx_v2; static llama_v2_context_params llama_ctx_params_v2; static llama_context_params llama_ctx_params; static llama_v2_context * llama_ctx_v2; @@ -390,17 +391,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } else if (file_format == FileFormat::RWKV_1) { - rwkv_ctx_v1 = rwkv_init_from_file(modelname.c_str(), n_threads); + rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads); //setup buffers for rwkv state auto padding = 512u; - auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding; - auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding; + auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding; + auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding; printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz); - rwkv_ctx_v1->state_out = (float *)malloc(statebufsiz); - rwkv_ctx_v1->logits_out = (float *)malloc(logitbufsiz); - rwkv_ctx_v1->state_in = nullptr; + rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz); + rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz); + rwkv_ctx_v2->state_in = nullptr; n_batch = 1; std::string word; @@ -414,15 +415,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } printf("\nRWKV Vocab: %u\n",vocabsiz); - bool testeval = rwkv_eval(rwkv_ctx_v1, 0, rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out); + bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out); if(!testeval) { printf("\nError: RWKV Init Eval Failed!\n"); } logits.resize(vocabsiz); - memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*vocabsiz); + memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz); - if (rwkv_ctx_v1 == NULL) + if (rwkv_ctx_v2 == NULL) { return ModelLoadResult::FAIL; } @@ -838,11 +839,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o n_vocab = vocab.id_to_token.size(); //handled seperately if(n_past==0) { - rwkv_ctx_v1->state_in = nullptr; + rwkv_ctx_v2->state_in = nullptr; } else { - rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out; + rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out; //if it's empty, push in the final previous token if(embd_inp.size()==0 && current_context_tokens.size()>0) { @@ -910,9 +911,9 @@ 
generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format==FileFormat::RWKV_1)
         {
-            evalres = rwkv_eval(rwkv_ctx_v1, embd[0], rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out);
-            memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*rwkv_vocab.size());
-            rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out;
+            evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
+            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
+            rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
         }
         else if(file_format==FileFormat::GPT2_1)
         {
diff --git a/otherarch/rwkv_v2.cpp b/otherarch/rwkv_v2.cpp
index a58830dce..d627a13f0 100644
--- a/otherarch/rwkv_v2.cpp
+++ b/otherarch/rwkv_v2.cpp
@@ -23,7 +23,7 @@
 // --- Utilities ---
 
 // Checks that x is not false. If x is false, prints fancy message to stderr and returns 0.
-#define RWKV_ASSERT_FALSE(x, ...) \
+#define RWKV_V2_ASSERT_FALSE(x, ...) \
     do { \
         if (!(x)) { \
             fprintf(stderr, __VA_ARGS__); \
@@ -33,7 +33,7 @@
     } while (0)
 
 // Checks that x is not false. If x is false, prints fancy message to stderr and returns NULL.
-#define RWKV_ASSERT_NULL(x, ...) \
+#define RWKV_V2_ASSERT_NULL(x, ...) \
     do { \
         if (!(x)) { \
             fprintf(stderr, __VA_ARGS__); \
@@ -43,16 +43,16 @@
     } while (0)
 
 // Reads single int32 value from a file.
-bool read_int32(FILE * file, int32_t * dest) {
-    RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
+bool rwkv_v2_read_int32(FILE * file, int32_t * dest) {
+    RWKV_V2_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
     return true;
 }
 
 #define GGML_V2_TYPE_UNKNOWN GGML_V2_TYPE_COUNT
 
-#define FORMAT_TYPE_COUNT 10
+#define RWKV_V2_FORMAT_TYPE_COUNT 10
 
-static const ggml_v2_type FORMAT_TYPE_TO_GGML_V2_TYPE[FORMAT_TYPE_COUNT] = {
+static const ggml_v2_type FORMAT_TYPE_TO_GGML_V2_TYPE[RWKV_V2_FORMAT_TYPE_COUNT] = {
     GGML_V2_TYPE_F32,
     GGML_V2_TYPE_F16,
     GGML_V2_TYPE_Q4_0,
@@ -65,7 +65,7 @@ static const ggml_v2_type FORMAT_TYPE_TO_GGML_V2_TYPE[FORMAT_TYPE_COUNT] = {
     GGML_V2_TYPE_Q8_0
 };
 
-static int32_t format_name_to_format_type(const char * format_name) {
+static int32_t rwkv_v2_format_name_to_format_type(const char * format_name) {
     if (strcmp(format_name, "Q4_0") == 0) return 2;
     if (strcmp(format_name, "Q4_1") == 0) return 3;
     if (strcmp(format_name, "Q4_2") == 0) return 5;
@@ -78,7 +78,7 @@ static int32_t format_name_to_format_type(const char * format_name) {
 
 // --- Model definition and loading utilities ---
 
-struct rwkv_layer {
+struct rwkv_v2_layer {
     struct ggml_v2_tensor * ln1_weight;
     struct ggml_v2_tensor * ln1_bias;
 
@@ -104,7 +104,7 @@ struct rwkv_layer {
     struct ggml_v2_tensor * ffn_receptance;
 };
 
-struct rwkv_model {
+struct rwkv_v2_model {
     int32_t n_vocab;
     int32_t n_layer;
     int32_t n_embed;
@@ -116,7 +116,7 @@ struct rwkv_model {
     struct ggml_v2_tensor * ln0_weight;
     struct ggml_v2_tensor * ln0_bias;
 
-    std::vector<rwkv_layer> layers;
+    std::vector<rwkv_v2_layer> layers;
 
     struct ggml_v2_tensor * ln_out_weight;
     struct ggml_v2_tensor * ln_out_bias;
@@ -126,64 +126,64 @@ struct rwkv_model {
 
 // Finds model parameter by key and sets it into dest.
 // If the parameter was not found, returns false.
-bool set_parameter(std::unordered_map<std::string, struct ggml_v2_tensor *> * parameters, char * key, struct ggml_v2_tensor ** dest) {
+bool rwkv_v2_set_parameter(std::unordered_map<std::string, struct ggml_v2_tensor *> * parameters, char * key, struct ggml_v2_tensor ** dest) {
     struct ggml_v2_tensor * parameter = (*parameters)[key];
-    RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key);
+    RWKV_V2_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key);
     *dest = parameter;
     return true;
 }
 
 // Finds block parameter by block index and key and sets it into dest.
 // If the parameter was not found, returns false.
-bool set_block_parameter(std::unordered_map<std::string, struct ggml_v2_tensor *> * parameters, int32_t block_index, char * key, struct ggml_v2_tensor ** dest) {
+bool rwkv_v2_set_block_parameter(std::unordered_map<std::string, struct ggml_v2_tensor *> * parameters, int32_t block_index, char * key, struct ggml_v2_tensor ** dest) {
     char full_key[128];
     sprintf(full_key, "blocks.%d.%s", block_index, key);
-    return set_parameter(parameters, full_key, dest);
+    return rwkv_v2_set_parameter(parameters, full_key, dest);
 }
 
 // --- Operators ---
 
-void rwkv_exp_impl(const int n_cols, float * dest, const float * src) {
+void rwkv_v2_exp_impl(const int n_cols, float * dest, const float * src) {
     for (int i = 0; i < n_cols; i++) {
         dest[i] = expf(src[i]);
     }
 }
 
-void rwkv_1_minus_x_impl(const int n_cols, float * dest, const float * src) {
+void rwkv_v2_1_minus_x_impl(const int n_cols, float * dest, const float * src) {
     for (int i = 0; i < n_cols; i++) {
         dest[i] = 1.0F - src[i];
     }
 }
 
-void rwkv_sigmoid_impl(const int n_cols, float * dest, const float * src) {
+void rwkv_v2_sigmoid_impl(const int n_cols, float * dest, const float * src) {
     for (int i = 0; i < n_cols; i++) {
         dest[i] = 1.0F / (1.0F + expf(-src[i]));
     }
 }
 
-void rwkv_max_impl(const int n_cols, float * dest, const float * src0, const float * src1) {
+void rwkv_v2_max_impl(const int n_cols, float * dest, const float * src0, const float * src1) {
     for (int i = 0; i < n_cols; i++) {
         dest[i] = fmaxf(src0[i], src1[i]);
     }
 }
 
-struct ggml_v2_tensor * rwkv_exp(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
-    return ggml_v2_map_unary_f32(ctx, x, rwkv_exp_impl);
+struct ggml_v2_tensor * rwkv_v2_exp(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
+    return ggml_v2_map_unary_f32(ctx, x, rwkv_v2_exp_impl);
 }
 
-struct ggml_v2_tensor * rwkv_1_minus_x(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
-    return ggml_v2_map_unary_f32(ctx, x, rwkv_1_minus_x_impl);
+struct ggml_v2_tensor * rwkv_v2_1_minus_x(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
+    return ggml_v2_map_unary_f32(ctx, x, rwkv_v2_1_minus_x_impl);
 }
 
-struct ggml_v2_tensor * rwkv_sigmoid(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
-    return ggml_v2_map_unary_f32(ctx, x, rwkv_sigmoid_impl);
+struct ggml_v2_tensor * rwkv_v2_sigmoid(ggml_v2_context * ctx, struct ggml_v2_tensor * x) {
+    return ggml_v2_map_unary_f32(ctx, x, rwkv_v2_sigmoid_impl);
 }
 
-struct ggml_v2_tensor * rwkv_max(ggml_v2_context * ctx, struct ggml_v2_tensor * x, struct ggml_v2_tensor * y) {
-    return ggml_v2_map_binary_f32(ctx, x, y, rwkv_max_impl);
+struct ggml_v2_tensor * rwkv_v2_max(ggml_v2_context * ctx, struct ggml_v2_tensor * x, struct ggml_v2_tensor * y) {
+    return ggml_v2_map_binary_f32(ctx, x, y, rwkv_v2_max_impl);
 }
 
-struct ggml_v2_tensor * rwkv_layer_norm(ggml_v2_context * ctx, struct ggml_v2_tensor * x, struct ggml_v2_tensor * weight, struct ggml_v2_tensor * bias) {
+struct ggml_v2_tensor * rwkv_v2_layer_norm(ggml_v2_context * ctx, struct ggml_v2_tensor * x, struct ggml_v2_tensor * weight, struct 
ggml_v2_tensor * bias) { // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` // Looks like ggml_v2_norm does the first part, we only need to apply weight & bias. x = ggml_v2_norm(ctx, x); @@ -194,8 +194,8 @@ struct ggml_v2_tensor * rwkv_layer_norm(ggml_v2_context * ctx, struct ggml_v2_te // --- Implementation --- -struct rwkv_context { - struct rwkv_model * model; +struct rwkv_v2_context { + struct rwkv_v2_model * model; struct ggml_v2_tensor * token_index; struct ggml_v2_tensor * state; struct ggml_v2_tensor ** state_parts; @@ -208,38 +208,38 @@ struct rwkv_context { float * logits_out = 0; //stores address of output logit buffer }; -struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) { +struct rwkv_v2_context * rwkv_v2_init_from_file(const char * file_path, uint32_t n_threads) { FILE * file = fopen(file_path, "rb"); - RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); + RWKV_V2_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); int32_t magic; - read_int32(file, &magic); - RWKV_ASSERT_NULL(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); + rwkv_v2_read_int32(file, &magic); + RWKV_V2_ASSERT_NULL(magic == RWKV_V2_FILE_MAGIC, "Unexpected magic value %d", magic); int32_t version; - read_int32(file, &version); - RWKV_ASSERT_NULL(version == RWKV_FILE_VERSION, "Unsupported file version %d", version); + rwkv_v2_read_int32(file, &version); + RWKV_V2_ASSERT_NULL(version == RWKV_V2_FILE_VERSION, "Unsupported file version %d", version); - struct rwkv_model * model = (struct rwkv_model *) calloc(1, sizeof(struct rwkv_model)); + struct rwkv_v2_model * model = (struct rwkv_v2_model *) calloc(1, sizeof(struct rwkv_v2_model)); - read_int32(file, &(model->n_vocab)); - RWKV_ASSERT_NULL(model->n_vocab > 0, "Non-positive n_vocab %d", model->n_vocab); + rwkv_v2_read_int32(file, &(model->n_vocab)); + RWKV_V2_ASSERT_NULL(model->n_vocab > 0, "Non-positive n_vocab %d", model->n_vocab); - read_int32(file, &(model->n_embed)); - RWKV_ASSERT_NULL(model->n_embed > 0, "Non-positive n_embed %d", model->n_embed); + rwkv_v2_read_int32(file, &(model->n_embed)); + RWKV_V2_ASSERT_NULL(model->n_embed > 0, "Non-positive n_embed %d", model->n_embed); - read_int32(file, &(model->n_layer)); - RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer); + rwkv_v2_read_int32(file, &(model->n_layer)); + RWKV_V2_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer); - read_int32(file, &(model->data_type)); - RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); + rwkv_v2_read_int32(file, &(model->data_type)); + RWKV_V2_ASSERT_NULL(model->data_type >= 0 && model->data_type < RWKV_V2_FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); - RWKV_ASSERT_NULL( + RWKV_V2_ASSERT_NULL( model->data_type != 4, "Models in Q4_1_O format cannot be loaded anymore because the format was removed. You need to quantize the model into another format" ); - RWKV_ASSERT_NULL( + RWKV_V2_ASSERT_NULL( model->data_type != 6, "Models in Q4_3 format cannot be loaded anymore because the format was removed. 
You need to quantize the model into another format"
     );
 
@@ -249,7 +249,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
     {
         auto fin = std::ifstream(file_path, std::ios::binary);
-        RWKV_ASSERT_NULL(fin, "Failed to open file %s", file_path);
+        RWKV_V2_ASSERT_NULL(fin, "Failed to open file %s", file_path);
         fin.seekg(0, fin.end);
         file_size = fin.tellg();
         fin.close();
@@ -283,20 +283,20 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             break;
         }
 
-        RWKV_ASSERT_NULL(elements_read == 1, "Failed to read dimension count");
-        RWKV_ASSERT_NULL(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
+        RWKV_V2_ASSERT_NULL(elements_read == 1, "Failed to read dimension count");
+        RWKV_V2_ASSERT_NULL(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
 
         int32_t key_length;
-        read_int32(file, &key_length);
-        RWKV_ASSERT_NULL(key_length > 0, "Non-positive key length %d", key_length);
+        rwkv_v2_read_int32(file, &key_length);
+        RWKV_V2_ASSERT_NULL(key_length > 0, "Non-positive key length %d", key_length);
 
         int32_t data_type;
-        read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
+        rwkv_v2_read_int32(file, &data_type);
+        RWKV_V2_ASSERT_NULL(data_type >= 0 && data_type < RWKV_V2_FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
 
         ggml_v2_type ggml_v2_data_type = FORMAT_TYPE_TO_GGML_V2_TYPE[data_type];
 
-        RWKV_ASSERT_NULL(ggml_v2_data_type != GGML_V2_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
+        RWKV_V2_ASSERT_NULL(ggml_v2_data_type != GGML_V2_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
 
         struct ggml_v2_tensor * tensor;
 
@@ -304,20 +304,20 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         int32_t y = -1;
 
         if (dim_count == 1) {
-            read_int32(file, &x);
+            rwkv_v2_read_int32(file, &x);
             tensor = ggml_v2_new_tensor_1d(ctx, ggml_v2_data_type, x);
         } else if (dim_count == 2) {
-            read_int32(file, &x);
-            read_int32(file, &y);
+            rwkv_v2_read_int32(file, &x);
+            rwkv_v2_read_int32(file, &y);
             tensor = ggml_v2_new_tensor_2d(ctx, ggml_v2_data_type, x, y);
         } else {
             abort();
         }
 
         std::string key(key_length, 0);
-        RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == uint32_t(key_length), "Failed to read parameter key");
+        RWKV_V2_ASSERT_NULL(fread(&key[0], 1, key_length, file) == uint32_t(key_length), "Failed to read parameter key");
 
-        RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_v2_nbytes(tensor), file) == ggml_v2_nbytes(tensor), "Failed to read parameter data");
+        RWKV_V2_ASSERT_NULL(fread(tensor->data, 1, ggml_v2_nbytes(tensor), file) == ggml_v2_nbytes(tensor), "Failed to read parameter data");
 
         parameters[key] = tensor;
     }
@@ -326,49 +326,49 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
 
     model->layers.resize(model->n_layer);
 
-    set_parameter(&parameters, "emb.weight", &(model->emb));
+    rwkv_v2_set_parameter(&parameters, "emb.weight", &(model->emb));
 
-    set_parameter(&parameters, "blocks.0.ln0.weight", &(model->ln0_weight));
-    set_parameter(&parameters, "blocks.0.ln0.bias", &(model->ln0_bias));
+    rwkv_v2_set_parameter(&parameters, "blocks.0.ln0.weight", &(model->ln0_weight));
+    rwkv_v2_set_parameter(&parameters, "blocks.0.ln0.bias", &(model->ln0_bias));
 
     for (int i = 0; i < model->n_layer; i++) {
-        rwkv_layer layer = model->layers[i];
+        rwkv_v2_layer layer = model->layers[i];
 
-        set_block_parameter(&parameters, i, "ln1.weight", &(layer.ln1_weight));
-        set_block_parameter(&parameters, i, "ln1.bias", &(layer.ln1_bias));
+        rwkv_v2_set_block_parameter(&parameters, i, "ln1.weight", &(layer.ln1_weight));
+        rwkv_v2_set_block_parameter(&parameters, i, "ln1.bias", &(layer.ln1_bias));
 
-        set_block_parameter(&parameters, i, "att.time_mix_k", &(layer.att_time_mix_k));
-        set_block_parameter(&parameters, i, "att.time_mix_v", &(layer.att_time_mix_v));
-        set_block_parameter(&parameters, i, "att.time_mix_r", &(layer.att_time_mix_r));
-        set_block_parameter(&parameters, i, "att.time_first", &(layer.att_time_first));
-        set_block_parameter(&parameters, i, "att.time_decay", &(layer.att_time_decay));
-        set_block_parameter(&parameters, i, "att.key.weight", &(layer.att_key));
-        set_block_parameter(&parameters, i, "att.value.weight", &(layer.att_value));
-        set_block_parameter(&parameters, i, "att.receptance.weight", &(layer.att_receptance));
-        set_block_parameter(&parameters, i, "att.output.weight", &(layer.att_output));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.time_mix_k", &(layer.att_time_mix_k));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.time_mix_v", &(layer.att_time_mix_v));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.time_mix_r", &(layer.att_time_mix_r));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.time_first", &(layer.att_time_first));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.time_decay", &(layer.att_time_decay));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.key.weight", &(layer.att_key));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.value.weight", &(layer.att_value));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.receptance.weight", &(layer.att_receptance));
+        rwkv_v2_set_block_parameter(&parameters, i, "att.output.weight", &(layer.att_output));
 
-        set_block_parameter(&parameters, i, "ln2.weight", &(layer.ln2_weight));
-        set_block_parameter(&parameters, i, "ln2.bias", &(layer.ln2_bias));
+        rwkv_v2_set_block_parameter(&parameters, i, "ln2.weight", &(layer.ln2_weight));
+        rwkv_v2_set_block_parameter(&parameters, i, "ln2.bias", &(layer.ln2_bias));
 
-        set_block_parameter(&parameters, i, "ffn.time_mix_k", &(layer.ffn_time_mix_k));
-        set_block_parameter(&parameters, i, "ffn.time_mix_r", &(layer.ffn_time_mix_r));
-        set_block_parameter(&parameters, i, "ffn.key.weight", &(layer.ffn_key));
-        set_block_parameter(&parameters, i, "ffn.value.weight", &(layer.ffn_value));
-        set_block_parameter(&parameters, i, "ffn.receptance.weight", &(layer.ffn_receptance));
+        rwkv_v2_set_block_parameter(&parameters, i, "ffn.time_mix_k", &(layer.ffn_time_mix_k));
+        rwkv_v2_set_block_parameter(&parameters, i, "ffn.time_mix_r", &(layer.ffn_time_mix_r));
+        rwkv_v2_set_block_parameter(&parameters, i, "ffn.key.weight", &(layer.ffn_key));
+        rwkv_v2_set_block_parameter(&parameters, i, "ffn.value.weight", &(layer.ffn_value));
+        rwkv_v2_set_block_parameter(&parameters, i, "ffn.receptance.weight", &(layer.ffn_receptance));
 
         model->layers[i] = layer;
     }
 
-    set_parameter(&parameters, "ln_out.weight", &(model->ln_out_weight));
-    set_parameter(&parameters, "ln_out.bias", &(model->ln_out_bias));
+    rwkv_v2_set_parameter(&parameters, "ln_out.weight", &(model->ln_out_weight));
+    rwkv_v2_set_parameter(&parameters, "ln_out.bias", &(model->ln_out_bias));
 
-    set_parameter(&parameters, "head.weight", &(model->head));
+    rwkv_v2_set_parameter(&parameters, "head.weight", &(model->head));
 
     // Verify order of dimensions
     struct ggml_v2_tensor * emb = model->emb;
-    RWKV_ASSERT_NULL(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]);
-    
RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]); + RWKV_V2_ASSERT_NULL(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); + RWKV_V2_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]); + RWKV_V2_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]); int32_t n_embed = model->n_embed; int32_t n_layer = model->n_layer; @@ -381,7 +381,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct ggml_v2_tensor * x = ggml_v2_get_rows(ctx, model->emb, token_index); // x = self.layer_norm(x, self.w.blocks[0].ln0) - x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); + x = rwkv_v2_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); // We collect parts of new state here. Each part is (n_embed) vector. struct ggml_v2_tensor ** state_parts = new ggml_v2_tensor * [n_layer * 5]; @@ -392,7 +392,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // RWKV/time mixing { // self.layer_norm(x, self.w.blocks[i].ln1) - struct ggml_v2_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + struct ggml_v2_tensor * x0 = rwkv_v2_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); // state[5 * i + 1] struct ggml_v2_tensor * x_prev = ggml_v2_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float)); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) @@ -401,23 +401,23 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct ggml_v2_tensor * xk = ggml_v2_add( ctx, ggml_v2_mul(ctx, x0, layer.att_time_mix_k), - ggml_v2_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_v2_mul(ctx, x_prev, rwkv_v2_1_minus_x(ctx, layer.att_time_mix_k)) ); struct ggml_v2_tensor * xv = ggml_v2_add( ctx, ggml_v2_mul(ctx, x0, layer.att_time_mix_v), - ggml_v2_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_v2_mul(ctx, x_prev, rwkv_v2_1_minus_x(ctx, layer.att_time_mix_v)) ); struct ggml_v2_tensor * xr = ggml_v2_add( ctx, ggml_v2_mul(ctx, x0, layer.att_time_mix_r), - ggml_v2_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_v2_mul(ctx, x_prev, rwkv_v2_1_minus_x(ctx, layer.att_time_mix_r)) ); // state[5 * i + 1] = x state_parts[5 * i + 1] = x0; // r = torch.sigmoid(rw @ xr) - struct ggml_v2_tensor * r = rwkv_sigmoid( + struct ggml_v2_tensor * r = rwkv_v2_sigmoid( ctx, ggml_v2_mul_mat(ctx, layer.att_receptance, xr) ); @@ -436,11 +436,11 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // ww = time_first + k struct ggml_v2_tensor * ww = ggml_v2_add(ctx, layer.att_time_first, k); // qq = torch.maximum(pp, ww) - struct ggml_v2_tensor * qq = rwkv_max(ctx, pp, ww); + struct ggml_v2_tensor * qq = rwkv_v2_max(ctx, pp, ww); // e1 = torch.exp(pp - qq) - struct ggml_v2_tensor * e1 = rwkv_exp(ctx, ggml_v2_sub(ctx, pp, qq)); + struct ggml_v2_tensor * e1 = rwkv_v2_exp(ctx, ggml_v2_sub(ctx, pp, qq)); // e2 = torch.exp(ww - qq) - struct ggml_v2_tensor * e2 = rwkv_exp(ctx, ggml_v2_sub(ctx, ww, qq)); + struct ggml_v2_tensor * e2 = rwkv_v2_exp(ctx, ggml_v2_sub(ctx, ww, qq)); // a = e1 * aa + e2 * v struct ggml_v2_tensor * a = ggml_v2_add( ctx, @@ -458,11 +458,11 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // ww = pp + time_decay ww = ggml_v2_add(ctx, pp, layer.att_time_decay); // qq = 
torch.maximum(ww, k) - qq = rwkv_max(ctx, ww, k); + qq = rwkv_v2_max(ctx, ww, k); // e1 = torch.exp(ww - qq) - e1 = rwkv_exp(ctx, ggml_v2_sub(ctx, ww, qq)); + e1 = rwkv_v2_exp(ctx, ggml_v2_sub(ctx, ww, qq)); // e2 = torch.exp(k - qq) - e2 = rwkv_exp(ctx, ggml_v2_sub(ctx, k, qq)); + e2 = rwkv_v2_exp(ctx, ggml_v2_sub(ctx, k, qq)); // state[5 * i + 2] = e1 * aa + e2 * v state_parts[5 * i + 2] = ggml_v2_add( ctx, @@ -492,7 +492,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // FFN/channel mixing { // self.layer_norm(x, self.w.blocks[i].ln2) - struct ggml_v2_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); + struct ggml_v2_tensor * x0 = rwkv_v2_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); // state[5 * i + 0] struct ggml_v2_tensor * x_prev = ggml_v2_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float)); // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) @@ -500,18 +500,18 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct ggml_v2_tensor * xk = ggml_v2_add( ctx, ggml_v2_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_v2_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_v2_mul(ctx, x_prev, rwkv_v2_1_minus_x(ctx, layer.ffn_time_mix_k)) ); struct ggml_v2_tensor * xr = ggml_v2_add( ctx, ggml_v2_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_v2_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_v2_mul(ctx, x_prev, rwkv_v2_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // state[5 * i + 0] = x state_parts[5 * i + 0] = x0; // r = torch.sigmoid(rw @ xr) - struct ggml_v2_tensor * r = rwkv_sigmoid( + struct ggml_v2_tensor * r = rwkv_v2_sigmoid( ctx, ggml_v2_mul_mat(ctx, layer.ffn_receptance, xr) ); @@ -534,7 +534,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr } // x = self.layer_norm(x, self.w.ln_out) - x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); + x = rwkv_v2_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); // x = (self.w.head.weight @ x).float() struct ggml_v2_tensor * logits = ggml_v2_mul_mat(ctx, model->head, x); @@ -549,7 +549,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr graph->n_threads = n_threads; - struct rwkv_context * rwkv_ctx = (struct rwkv_context *) calloc(1, sizeof(struct rwkv_context)); + struct rwkv_v2_context * rwkv_ctx = (struct rwkv_v2_context *) calloc(1, sizeof(struct rwkv_v2_context)); rwkv_ctx->model = model; rwkv_ctx->token_index = token_index; rwkv_ctx->state = state; @@ -560,23 +560,23 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr return rwkv_ctx; } -uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_v2_get_state_buffer_element_count(struct rwkv_v2_context * ctx) { return ctx->model->n_layer * 5 * ctx->model->n_embed; } -uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_v2_get_logits_buffer_element_count(struct rwkv_v2_context * ctx) { return ctx->model->n_vocab; } -bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { - RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); - RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); +bool rwkv_v2_eval(struct rwkv_v2_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { + RWKV_V2_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); + 
RWKV_V2_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); int32_t n_layer = ctx->model->n_layer; int32_t n_embed = ctx->model->n_embed; int32_t n_vocab = ctx->model->n_vocab; - RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1); + RWKV_V2_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1); ggml_v2_set_i32_1d(ctx->token_index, 0, token); @@ -607,7 +607,7 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float return true; } -void rwkv_free(struct rwkv_context * ctx) { +void rwkv_v2_free(struct rwkv_v2_context * ctx) { ctx->model->layers.~vector(); free(ctx->model); delete[] ctx->state_parts; @@ -616,14 +616,14 @@ void rwkv_free(struct rwkv_context * ctx) { free(ctx); } -bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { - int32_t format_type = format_name_to_format_type(format_name); +bool rwkv_v2_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { + int32_t format_type = rwkv_v2_format_name_to_format_type(format_name); - RWKV_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name); + RWKV_V2_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name); ggml_v2_type type = FORMAT_TYPE_TO_GGML_V2_TYPE[format_type]; - RWKV_ASSERT_FALSE(type != GGML_V2_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name); + RWKV_V2_ASSERT_FALSE(type != GGML_V2_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name); // Needed to initialize FP16 lookup table { @@ -635,21 +635,21 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode printf("Loading model from '%s'\n", model_file_path_in); auto finp = std::ifstream(model_file_path_in, std::ios::binary); - RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in); + RWKV_V2_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in); auto fout = std::ofstream(model_file_path_out, std::ios::binary); - RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out); + RWKV_V2_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out); // Process header { uint32_t magic; finp.read((char *) &magic, sizeof(magic)); - RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); + RWKV_V2_ASSERT_FALSE(magic == RWKV_V2_FILE_MAGIC, "Unexpected magic value %d", magic); fout.write((char *) &magic, sizeof(magic)); uint32_t format_version; finp.read((char *) &format_version, sizeof(format_version)); - RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version); + RWKV_V2_ASSERT_FALSE(format_version == RWKV_V2_FILE_VERSION, "Unsupported file version %d", format_version); fout.write((char *) &format_version, sizeof(format_version)); int32_t n_vocab; @@ -662,7 +662,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode finp.read((char *) &n_layer, sizeof(n_layer)); finp.read((char *) &data_type, sizeof(data_type)); - RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type); + RWKV_V2_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type); data_type = format_type; @@ -698,11 +698,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode break; } - 
RWKV_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type); + RWKV_V2_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < RWKV_V2_FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type); ggml_v2_type parameter_ggml_v2_type = FORMAT_TYPE_TO_GGML_V2_TYPE[parameter_data_type]; - RWKV_ASSERT_FALSE(parameter_ggml_v2_type != GGML_V2_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type); + RWKV_V2_ASSERT_FALSE(parameter_ggml_v2_type != GGML_V2_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type); int32_t nelements = 1; int32_t ne[2] = { 1, 1 }; @@ -728,7 +728,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode name != std::string("head.weight"); if (quantize) { - RWKV_ASSERT_FALSE( + RWKV_V2_ASSERT_FALSE( parameter_data_type == 0 || parameter_data_type == 1, "Unsupported parameter data type %d, only FP32 and FP16 can be quantized", parameter_data_type @@ -844,7 +844,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode return true; } -const char * rwkv_get_system_info_string(void) { +const char * rwkv_v2_get_system_info_string(void) { static std::string s; s = ""; diff --git a/otherarch/rwkv_v2.h b/otherarch/rwkv_v2.h index 24b9535dd..ea8388f24 100644 --- a/otherarch/rwkv_v2.h +++ b/otherarch/rwkv_v2.h @@ -1,56 +1,56 @@ -#ifndef RWKV_H -#define RWKV_H +#ifndef RWKV_H2 +#define RWKV_H2 #include #include #include -#ifdef RWKV_SHARED +#ifdef RWKV_SHARED2 # if defined(_WIN32) && !defined(__MINGW32__) # ifdef RWKV_BUILD -# define RWKV_API __declspec(dllexport) +# define RWKV_V2_API __declspec(dllexport) # else -# define RWKV_API __declspec(dllimport) +# define RWKV_V2_API __declspec(dllimport) # endif # else -# define RWKV_API __attribute__ ((visibility ("default"))) +# define RWKV_V2_API __attribute__ ((visibility ("default"))) # endif #else -# define RWKV_API +# define RWKV_V2_API #endif // 'ggmf' in hex. -#define RWKV_FILE_MAGIC 0x67676d66 -#define RWKV_FILE_VERSION 100 +#define RWKV_V2_FILE_MAGIC 0x67676d66 +#define RWKV_V2_FILE_VERSION 100 #ifdef __cplusplus extern "C" { #endif - struct rwkv_context; + struct rwkv_v2_context; // Loads the model from a file and prepares it for inference. // Returns NULL on any error. Error messages would be printed to stderr. // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads); + RWKV_V2_API struct rwkv_v2_context * rwkv_v2_init_from_file(const char * model_file_path, uint32_t n_threads); // Evaluates the model for a single token. // Returns false on any error. Error messages would be printed to stderr. // - token: next token index, in range 0 <= token < n_vocab. - // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. - // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. - // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. - RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); + // - state_in: FP32 buffer of size rwkv_v2_get_state_buffer_element_count; or NULL, if this is a first pass. + // - state_out: FP32 buffer of size rwkv_v2_get_state_buffer_element_count. 
This buffer will be written to. + // - logits_out: FP32 buffer of size rwkv_v2_get_logits_buffer_element_count. This buffer will be written to. + RWKV_V2_API bool rwkv_v2_eval(struct rwkv_v2_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); // Returns count of FP32 elements in state buffer. - RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); + RWKV_V2_API uint32_t rwkv_v2_get_state_buffer_element_count(struct rwkv_v2_context * ctx); // Returns count of FP32 elements in logits buffer. - RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); + RWKV_V2_API uint32_t rwkv_v2_get_logits_buffer_element_count(struct rwkv_v2_context * ctx); // Frees all allocated memory and the context. - RWKV_API void rwkv_free(struct rwkv_context * ctx); + RWKV_V2_API void rwkv_v2_free(struct rwkv_v2_context * ctx); // Quantizes FP32 or FP16 model to one of quantized formats. // Returns false on any error. Error messages would be printed to stderr. @@ -64,10 +64,10 @@ extern "C" { // - Q5_0 // - Q5_1 // - Q8_0 - RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name); + RWKV_V2_API bool rwkv_v2_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name); // Returns system information string. - RWKV_API const char * rwkv_get_system_info_string(void); + RWKV_V2_API const char * rwkv_v2_get_system_info_string(void); #ifdef __cplusplus } diff --git a/otherarch/rwkv_v3.cpp b/otherarch/rwkv_v3.cpp new file mode 100644 index 000000000..04d228eee --- /dev/null +++ b/otherarch/rwkv_v3.cpp @@ -0,0 +1,948 @@ +//adapted from RWKV.cpp repo under MIT license +// https://github.com/saharNooby/rwkv.cpp + +#include "otherarch.h" + +#include "rwkv_v3.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // fstat + +#ifdef WIN32 +#define stat64 _stat64 +#define fstat64 _fstat64 +#endif + +// --- Error handling --- + +enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE; +bool global_print_errors = true; + +enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) { + return static_cast(static_cast(a) | static_cast(b)); +} + +enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) { + return a = a | b; +} + +// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL. +#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) \ + if (!(x)) { \ + global_last_error |= ERR_VAL; \ + if (global_print_errors) { \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + } \ + return RET_VAL; \ + } + +// If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL. +#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) \ + if (!(x)) { \ + global_last_error |= ERR_VAL; \ + return RET_VAL; \ + } + +// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL. +#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) 
\ + if (!(x)) { \ + ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ + if (ctx->print_errors) { \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + } \ + return RET_VAL; \ + } + +// If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL. +#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) \ + if (!(x)) { \ + ctx->last_error |= ERR_VAL; \ + return RET_VAL; \ + } + +#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__) +#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__) +#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__) +#define RWKV_CTX_ASSERT_NULL_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, NULL, x, __VA_ARGS__) + +#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x) +#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x) +#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x) +#define RWKV_CTX_ASSERT_NULL(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, NULL, x) + +// --- Utilities --- + +// Reads single int32 value from a file. +bool read_int32(FILE * file, int32_t * dest, const char * name) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file (%s)", name); + return true; +} + +// Reads single uint32 value from a file. +bool read_uint32(FILE * file, uint32_t * dest, const char * name) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file (%s)", name); + return true; +} + +// Writes single int32 value to a file. +bool write_int32(FILE * file, int32_t value, const char * name) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(int32_t), 1, file), "Failed to write an int32 value to a file (%s)", name); + return true; +} + +// Writes single uint32 value to a file. +bool write_uint32(FILE * file, uint32_t value, const char * name) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(uint32_t), 1, file), "Failed to write a uint32 value to a file (%s)", name); + return true; +} + +#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT + +#define FORMAT_TYPE_COUNT 10 + +static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_UNKNOWN, // Unused + GGML_TYPE_UNKNOWN, // Unused + GGML_TYPE_UNKNOWN, // Unused + GGML_TYPE_Q5_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0 +}; + +static bool is_non_quantized_format_type(int32_t format_type) { + return format_type == 0 || format_type == 1; +} + +static bool is_quantized_format_type(int32_t format_type) { + return format_type == 2 || + format_type == 3 || + format_type == 7 || + format_type == 8 || + format_type == 9; +} + +static int32_t format_name_to_format_type(const char * format_name) { + if (strcmp(format_name, "Q4_0") == 0) return 2; + if (strcmp(format_name, "Q4_1") == 0) return 3; + if (strcmp(format_name, "Q5_0") == 0) return 7; + if (strcmp(format_name, "Q5_1") == 0) return 8; + if (strcmp(format_name, "Q8_0") == 0) return 9; + + return -1; +} + +// --- Model definition and loading utilities --- + +struct rwkv_layer { + struct ggml_tensor * ln1_weight; + struct ggml_tensor * ln1_bias; + + // RWKV, also called "attention" by the author. 
+    struct ggml_tensor * att_time_mix_k;
+    struct ggml_tensor * att_time_mix_v;
+    struct ggml_tensor * att_time_mix_r;
+    struct ggml_tensor * att_time_first;
+    struct ggml_tensor * att_time_decay;
+    struct ggml_tensor * att_key;
+    struct ggml_tensor * att_value;
+    struct ggml_tensor * att_receptance;
+    struct ggml_tensor * att_output;
+
+    struct ggml_tensor * ln2_weight;
+    struct ggml_tensor * ln2_bias;
+
+    // FFN.
+    struct ggml_tensor * ffn_time_mix_k;
+    struct ggml_tensor * ffn_time_mix_r;
+    struct ggml_tensor * ffn_key;
+    struct ggml_tensor * ffn_value;
+    struct ggml_tensor * ffn_receptance;
+};
+
+struct rwkv_model {
+    uint32_t n_vocab;
+    uint32_t n_layer;
+    uint32_t n_embed;
+    // 0 for float32, 1 for float16.
+    int32_t data_type;
+
+    struct ggml_tensor * emb;
+
+    struct ggml_tensor * ln0_weight;
+    struct ggml_tensor * ln0_bias;
+
+    std::vector<rwkv_layer> layers;
+
+    struct ggml_tensor * ln_out_weight;
+    struct ggml_tensor * ln_out_bias;
+
+    struct ggml_tensor * head;
+};
+
+// Finds model parameter by key and sets it into dest.
+// If the parameter was not found, returns false.
+bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, std::string key, struct ggml_tensor ** dest) {
+    struct ggml_tensor * parameter = (*parameters)[key];
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_PARAM_MISSING, parameter != NULL, "Parameter %s not found in model file", key.c_str());
+    *dest = parameter;
+    return true;
+}
+
+// Finds block parameter by block index and key and sets it into dest.
+// If the parameter was not found, returns false.
+bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, int32_t block_index, std::string key, struct ggml_tensor ** dest) {
+    char full_key[128];
+    sprintf(full_key, "blocks.%d.%s", block_index, key.c_str());
+    return set_parameter(parameters, full_key, dest);
+}
+
+// --- Operators ---
+
+void rwkv_exp_impl(const int n_cols, float * dest, const float * src) {
+    for (int i = 0; i < n_cols; i++) {
+        dest[i] = expf(src[i]);
+    }
+}
+
+void rwkv_1_minus_x_impl(const int n_cols, float * dest, const float * src) {
+    for (int i = 0; i < n_cols; i++) {
+        dest[i] = 1.0F - src[i];
+    }
+}
+
+void rwkv_sigmoid_impl(const int n_cols, float * dest, const float * src) {
+    for (int i = 0; i < n_cols; i++) {
+        dest[i] = 1.0F / (1.0F + expf(-src[i]));
+    }
+}
+
+void rwkv_max_impl(const int n_cols, float * dest, const float * src0, const float * src1) {
+    for (int i = 0; i < n_cols; i++) {
+        dest[i] = fmaxf(src0[i], src1[i]);
+    }
+}
+
+struct ggml_tensor * rwkv_exp(ggml_context * ctx, struct ggml_tensor * x) {
+    return ggml_map_unary_f32(ctx, x, rwkv_exp_impl);
+}
+
+struct ggml_tensor * rwkv_1_minus_x(ggml_context * ctx, struct ggml_tensor * x) {
+    return ggml_map_unary_f32(ctx, x, rwkv_1_minus_x_impl);
+}
+
+struct ggml_tensor * rwkv_sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
+    return ggml_map_unary_f32(ctx, x, rwkv_sigmoid_impl);
+}
+
+struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) {
+    return ggml_map_binary_f32(ctx, x, y, rwkv_max_impl);
+}
+
+struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
+    // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
+    // Looks like ggml_norm does the first part, we only need to apply weight & bias.
+    return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias);
+}
+
+// --- Implementation ---
+
+struct rwkv_graph {
+    struct ggml_tensor * state;
+    std::unique_ptr<struct ggml_tensor * []> state_parts;
+    struct ggml_tensor * token_index;
+    struct ggml_tensor * logits;
+    std::unique_ptr<struct ggml_cgraph> cgraph;
+};
+
+struct rwkv_context {
+    std::unique_ptr<struct rwkv_model> model;
+    struct ggml_context * ctx;
+    struct rwkv_graph graph;
+    enum rwkv_error_flags last_error;
+    bool print_errors;
+
+    float * state_in = 0; //stores input state, or use null for a new state
+    float * state_out = 0; //stores address of output state buffer
+    float * logits_out = 0; //stores address of output logit buffer
+};
+
+void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
+    bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
+    *ptr = print_errors;
+}
+
+bool rwkv_get_print_errors(struct rwkv_context * ctx) {
+    return ctx ? ctx->print_errors : global_print_errors;
+}
+
+enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
+    enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error;
+    enum rwkv_error_flags value = *ptr;
+    *ptr = RWKV_ERROR_NONE;
+    return value;
+}
+
+bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, const uint32_t n_threads, struct rwkv_graph * out) {
+    std::unique_ptr<struct ggml_cgraph> cgraph(new(std::nothrow) struct ggml_cgraph());
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph");
+    cgraph->n_threads = n_threads;
+
+    size_t n_embed = model->n_embed, n_layer = model->n_layer;
+    struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
+
+    // We collect parts of new state here. Each part is (n_embed) vector.
+    std::unique_ptr<struct ggml_tensor * []> state_parts(new(std::nothrow) ggml_tensor * [n_layer * 5]);
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, state_parts.get(), "Failed to allocate state parts");
+
+    // x = self.w.emb.weight[token]
+    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
+
+    // x = self.layer_norm(x, self.w.blocks[0].ln0)
+    x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias);
+
+    for (size_t i = 0; i < n_layer; i++) {
+        struct rwkv_layer layer = model->layers[i];
+        size_t part_index = i * 5;
+        size_t state_part_size = n_embed * sizeof(float);
+
+        // RWKV/time mixing
+        {
+            // self.layer_norm(x, self.w.blocks[i].ln1)
+            struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
+
+            // x0 = state[5 * i + 1]
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (part_index + 1) * state_part_size);
+            // aa = state[5 * i + 2]
+            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (part_index + 2) * state_part_size);
+            // bb = state[5 * i + 3]
+            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (part_index + 3) * state_part_size);
+            // pp = state[5 * i + 4]
+            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (part_index + 4) * state_part_size);
+
+            // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
+            struct ggml_tensor * xk = ggml_add_inplace(ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_k),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
+            );
+
+            // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
+            struct ggml_tensor * xv = ggml_add_inplace(ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_v),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
+            );
+
+            // xr = x * time_mix_r + state[5 * 
i + 1] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + // k = kw @ xk + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); + + // ww = time_first + k + struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); + // qq = torch.maximum(pp, ww) + struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); + // e1 = torch.exp(pp - qq) + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); + // e2 = torch.exp(ww - qq) + struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + + // a = e1 * aa + e2 * v + struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + // b = e1 * bb + e2 + struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + + // ww = pp + time_decay + ww = ggml_add_inplace(ctx, pp, layer.att_time_decay); + // qq = torch.maximum(ww, k) + qq = rwkv_max(ctx, ww, k); + // e1 = torch.exp(ww - qq) + e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + // e2 = torch.exp(k - qq) + e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); + + // state[5 * i + 1] = x0 + // state[5 * i + 2] = e1 * aa + e2 * v + // state[5 * i + 3] = e1 * bb + e2 + // state[5 * i + 4] = qq + state_parts[part_index + 1] = x0; + state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + state_parts[part_index + 4] = qq; + + // wkv = a / b + struct ggml_tensor * wkv = ggml_div(ctx, a, b); + + // ow @ (r * wkv) + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv))); + } + + // FFN/channel mixing + { + // self.layer_norm(x, self.w.blocks[i].ln2) + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); + + // x_prev = state[5 * i + 0] + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, part_index * state_part_size); + + // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ); + + // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ); + + // state[5 * i + 0] = x + state_parts[part_index] = x0; + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + x = ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k))); + } + } + + // x = self.layer_norm(x, self.w.ln_out) + x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); + + // x = (self.w.head.weight @ x).float() + struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); + + ggml_build_forward_expand(cgraph.get(), logits); + + for (uint32_t i = 0; i < n_layer * 5; i++) { + ggml_build_forward_expand(cgraph.get(), state_parts[i]); + } + + out->state = 
state; + out->state_parts = std::move(state_parts); + out->token_index = token_index; + out->logits = logits; + out->cgraph = std::move(cgraph); + return true; +} + +struct rwkv_file_guard { + FILE * file; + ~rwkv_file_guard() { if (file) fclose(file); } +}; + +struct rwkv_ggml_guard { + struct ggml_context * ctx; + ~rwkv_ggml_guard() { if (ctx) ggml_free(ctx); } +}; + +struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { + global_last_error = RWKV_ERROR_NONE; + + FILE * file = fopen(file_path, "rb"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path); + rwkv_file_guard file_guard { file }; + + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. + struct stat64 file_stat; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat64(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path); + + int32_t magic; + RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &magic, "magic")); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); + + int32_t version; + RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version")); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, "Unsupported file version %d", version); + + std::unique_ptr model(new(std::nothrow) struct rwkv_model()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model"); + + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab")); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed")); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer")); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type")); + + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); + + const char * unsupported_type_msg = "Models in %s format cannot be loaded anymore because the format was removed.\n" + "You need to quantize the model into another format or use an older version of rwkv.cpp.\n" + "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info"; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_type_msg, "Q4_1_O"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 5, unsupported_type_msg, "Q4_2"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_type_msg, "Q4_3"); + + RWKV_ASSERT_NULL_MSG( + RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, + !is_quantized_format_type(model->data_type) || version >= RWKV_FILE_VERSION_1, + "The quantized model file was created with an old version of rwkv.cpp and can not be loaded anymore.\n" + "You need to requantize the model or use an older version of rwkv.cpp.\n" + "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info" + ); + + size_t memory_required = file_stat.st_size + + // Intermediary vectors for calculation; there are around 100 calls to ggml + size_t(100) * model->n_embed * sizeof(float) + + // State, in and out + size_t(2) * 5 * model->n_layer * model->n_embed * sizeof(float) + + // Logits + 
size_t(model->n_vocab) * sizeof(float) +
+        // +256 MB just for any overhead
+        // TODO This is too much for smaller models; need a more proper and robust way of measuring required memory
+        size_t(256) * 1024 * 1024;
+
+    struct ggml_context * ctx = ggml_init({ memory_required, NULL, false });
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, ctx, "Failed to allocate GGML context");
+    rwkv_ggml_guard ggml_guard { ctx };
+
+    std::unordered_map<std::string, struct ggml_tensor *> parameters;
+
+    while (true) {
+        int32_t dim_count, key_length, data_type;
+        RWKV_ASSERT_NULL_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ,
+            fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file),
+            "Failed to read an int32 value from a file (dim_count)"
+        );
+
+        if (feof(file)) {
+            break;
+        }
+
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length"));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type"));
+
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, key_length > 0, "Non-positive key length %d", key_length);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
+
+        ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
+
+        struct ggml_tensor * tensor;
+
+        if (dim_count == 1) {
+            int32_t x;
+            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter length");
+            tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
+        } else {
+            int32_t x, y;
+            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter width");
+            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &y, "y"), "Failed to read parameter height");
+            tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y);
+        }
+
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
+
+        std::string key(key_length, 0);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&key[0], key_length, 1, file) == 1, "Failed to read parameter key");
+
+        size_t nbytes = ggml_nbytes(tensor);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(tensor->data, nbytes, 1, file) == 1, "Failed to read parameter data");
+
+        parameters[key] = tensor;
+    }
+
+    file_guard = { NULL }; // close file
+
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "emb.weight", &model->emb));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.weight", &model->ln0_weight));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.bias", &model->ln0_bias));
+
+    model->layers.resize(model->n_layer);
+
+    for (uint32_t i = 0; i < model->n_layer; i++) {
+        rwkv_layer * layer = &model->layers[i];
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.weight", &layer->ln1_weight));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.bias", &layer->ln1_bias));
+
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_k", &layer->att_time_mix_k));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_v", &layer->att_time_mix_v));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_r", &layer->att_time_mix_r));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_first", &layer->att_time_first));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_decay", &layer->att_time_decay));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.key.weight", &layer->att_key));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.value.weight", &layer->att_value));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.receptance.weight", &layer->att_receptance));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.output.weight", &layer->att_output));
+
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.weight", &layer->ln2_weight));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.bias", &layer->ln2_bias));
+
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_k", &layer->ffn_time_mix_k));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_r", &layer->ffn_time_mix_r));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.key.weight", &layer->ffn_key));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.value.weight", &layer->ffn_value));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
+    }
+
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
+
+    // Verify order of dimensions
+    struct ggml_tensor * emb = model->emb;
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
+
+    // Build graph
+    struct rwkv_graph graph;
+    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph));
+
+    std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
+    rwkv_ctx->model = std::move(model);
+    rwkv_ctx->ctx = ctx;
+    rwkv_ctx->graph = std::move(graph);
+    rwkv_ctx->last_error = RWKV_ERROR_NONE;
+    rwkv_ctx->print_errors = global_print_errors;
+    // Don't free ggml context
+    ggml_guard.ctx = NULL;
+    return rwkv_ctx.release();
+}
+
+uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
+    return ctx->model->n_layer * 5 * ctx->model->n_embed;
+}
+
+uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
+    return ctx->model->n_vocab;
+}
+
+bool rwkv_eval(const
struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { + ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; + + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, state_out != NULL, "state_out is NULL"); + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL"); + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1); + + const struct rwkv_graph * graph = &ctx->graph; + size_t n_layer = ctx->model->n_layer; + size_t n_embed = ctx->model->n_embed; + + ggml_set_i32_1d(graph->token_index, 0, token); + + if (state_in == NULL) { + ggml_set_f32(graph->state, 0.0F); + + for (size_t i = 0; i < n_layer; i++) { + // state[5 * i + 4] = -1e30 + ggml_set_f32( + ggml_view_1d(ctx->ctx, graph->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), + -1e30F + ); + } + } else { + memcpy(graph->state->data, state_in, graph->state->ne[0] * sizeof(float)); + } + + ggml_graph_compute(ctx->ctx, graph->cgraph.get()); + + for (size_t i = 0; i < n_layer * 5; i++) { + struct ggml_tensor * part = graph->state_parts[i]; + memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); + } + + memcpy(logits_out, graph->logits->data, graph->logits->ne[0] * sizeof(float)); + + return true; +} + +void rwkv_free(struct rwkv_context * ctx) { + std::unique_ptr rwkv_ctx(ctx); + ggml_free(ctx->ctx); +} + +bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { + global_last_error = RWKV_ERROR_NONE; + + int32_t format_data_type = format_name_to_format_type(format_name); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_data_type != -1, "Unsupported format \"%s\"", format_name); + + ggml_type format_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[format_data_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name); + + // Needed to initialize FP16 lookup table + ggml_free(ggml_init({ 0, NULL, false })); + + printf("Loading model from '%s'\n", model_file_path_in); + + FILE * file_in = fopen(model_file_path_in, "rb"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_in, "Failed to open %s for reading", model_file_path_in); + FILE * file_out = fopen(model_file_path_out, "wb"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_out, "Failed to open %s for writing", model_file_path_out); + + rwkv_file_guard file_in_guard { file_in }; + rwkv_file_guard file_out_guard { file_out }; + + // Process header + { + uint32_t magic, version; + int32_t n_vocab, n_embed, n_layer, data_type; + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic")); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version")); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, + version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, + "Unsupported file version %d", + version + ); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer")); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, 
read_int32(file_in, &data_type, "data_type")); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, + is_non_quantized_format_type(data_type), + "Unsupported data type %d, only FP32 and FP16 can be quantized", + data_type + ); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic")); + // Always write latest version number when saving files + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, RWKV_FILE_VERSION_MAX, "version")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, format_data_type, "data_type")); + } + + // Process parameters + size_t total_size_orig = 0; + size_t total_size_new = 0; + + std::vector work; + + std::vector data_u8; + std::vector data_f16; + std::vector data_f32; + + std::vector hist_all(1 << 4, 0); + + while (true) { + int32_t n_dims, key_length, parameter_data_type; + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, + fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), + "Failed to read an int32 value from a file (n_dims)" + ); + + if (feof(file_in)) { + break; + } + + RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length")); + RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, ¶meter_data_type, "parameter_data_type")); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, + parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, + "Unsupported parameter data type %d", + parameter_data_type + ); + + ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]; + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, + parameter_ggml_type != GGML_TYPE_UNKNOWN, + "Unsupported parameter data type %d", + parameter_data_type + ); + + int32_t nelements, x, y; + + if (n_dims == 1) { + RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x")); + y = 1; + nelements = x; + } else { + RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x")); + RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &y, "y")); + nelements = x * y; + } + + std::string name(key_length, 0); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&name[0], key_length, 1, file_in) == 1, "Failed to read parameter key"); + + printf("%48s - [%5d, %5d], type = %6s ", name.data(), x, y, ggml_type_name(parameter_ggml_type)); + total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type)); + + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take not too much space, especially in bigger models; + // but they significantly increase perplexity when quantized. 
+ bool quantize = n_dims == 2 && name != "emb.weight" && name != "head.weight"; + + if (quantize) { + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA_TYPE, + parameter_ggml_type == GGML_TYPE_F32 || parameter_data_type == GGML_TYPE_F16, + "Unsupported parameter data type %d, only FP32 and FP16 can be quantized", + parameter_ggml_type + ); + + data_f32.resize(nelements); + + if (parameter_data_type == GGML_TYPE_F16) { + data_f16.resize(nelements); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, + fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, + "Failed to read parameter data" + ); + + for (int i = 0; i < nelements; ++i) { + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } + } else { + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, + fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, + "Failed to read parameter data" + ); + } + + parameter_data_type = format_data_type; + parameter_ggml_type = format_ggml_type; + } else { + const size_t element_size = ggml_type_size(parameter_ggml_type); + data_u8.resize(nelements * element_size); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, + fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, + "Failed to read parameter data" + ); + } + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, key_length, "key_length")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, parameter_data_type, "parameter_data_type")); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x")); + + if (n_dims == 2) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y")); + } + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key"); + + if (quantize) { + printf("quantizing... "); + // For quantization + work.resize(nelements); + + // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! + std::vector hist_cur(1 << 4, 0); + + size_t (*f)(const float * src, void * dst, int n, int k, int64_t * hist) = + format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 : + format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 : + format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 : + format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 : + format_ggml_type == GGML_TYPE_Q8_0 ? 
ggml_quantize_q8_0 : + NULL; + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "Unsupported quantization type %d\n", format_ggml_type); + + size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data()); + total_size_new += cur_size; + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(work.data(), cur_size, 1, file_out) == 1, "Failed to write parameter data"); + + printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0); + + for (int i = 0; i < (int) hist_cur.size(); ++i) { + hist_all[i] += hist_cur[i]; + } + + for (int i = 0; i < (int) hist_cur.size(); ++i) { + printf("%5.3f ", hist_cur[i] / float(nelements)); + } + + printf("\n"); + } else { + printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(data_u8.data(), data_u8.size(), 1, file_out) == 1, "Failed to write parameter data"); + total_size_new += data_u8.size(); + } + } + + printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0); + printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0); + printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new); + + int64_t sum_all = 0; + + for (int i = 0; i < (int) hist_all.size(); ++i) { + sum_all += hist_all[i]; + } + + printf("hist: "); + + for (int i = 0; i < (int) hist_all.size(); ++i) { + printf("%5.3f ", hist_all[i] / float(sum_all)); + } + + printf("\n"); + + return true; +} + +const char * rwkv_get_system_info_string(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} \ No newline at end of file diff --git a/otherarch/rwkv_v3.h b/otherarch/rwkv_v3.h new file mode 100644 index 000000000..8e09264d7 --- /dev/null +++ b/otherarch/rwkv_v3.h @@ -0,0 +1,125 @@ +#ifndef RWKV_H +#define RWKV_H + +#include +#include +#include + +#ifdef RWKV_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef RWKV_BUILD +# define RWKV_API __declspec(dllexport) +# else +# define RWKV_API __declspec(dllimport) +# endif +# else +# define RWKV_API __attribute__ ((visibility ("default"))) +# endif +#else +# define RWKV_API +#endif + +// 'ggmf' in hex. +#define RWKV_FILE_MAGIC 0x67676d66 + +#define RWKV_FILE_VERSION_0 100 +#define RWKV_FILE_VERSION_1 101 +#define RWKV_FILE_VERSION_MIN RWKV_FILE_VERSION_0 +#define RWKV_FILE_VERSION_MAX RWKV_FILE_VERSION_1 +// Default file version is the latest version. +#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX + +#ifdef __cplusplus +extern "C" { +#endif + + // Represents an error encountered during a function call. + // These are flags, so an actual value might contain multiple errors. 
+    enum rwkv_error_flags {
+        RWKV_ERROR_NONE = 0,
+
+        RWKV_ERROR_ARGS = 1 << 8,
+        RWKV_ERROR_FILE = 2 << 8,
+        RWKV_ERROR_MODEL = 3 << 8,
+        RWKV_ERROR_MODEL_PARAMS = 4 << 8,
+        RWKV_ERROR_GRAPH = 5 << 8,
+        RWKV_ERROR_CTX = 6 << 8,
+
+        RWKV_ERROR_ALLOC = 1,
+        RWKV_ERROR_FILE_OPEN = 2,
+        RWKV_ERROR_FILE_STAT = 3,
+        RWKV_ERROR_FILE_READ = 4,
+        RWKV_ERROR_FILE_WRITE = 5,
+        RWKV_ERROR_FILE_MAGIC = 6,
+        RWKV_ERROR_FILE_VERSION = 7,
+        RWKV_ERROR_DATA_TYPE = 8,
+        RWKV_ERROR_UNSUPPORTED = 9,
+        RWKV_ERROR_SHAPE = 10,
+        RWKV_ERROR_DIMENSION = 11,
+        RWKV_ERROR_KEY = 12,
+        RWKV_ERROR_DATA = 13,
+        RWKV_ERROR_PARAM_MISSING = 14
+    };
+
+    struct rwkv_context;
+
+    // Sets whether errors are automatically printed to stderr.
+    // If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
+    // - ctx: the context to suppress error messages for.
+    //   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
+    //   as well as the default for new contexts.
+    // - print_errors: whether error messages should be automatically printed.
+    RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
+
+    // Gets whether errors are automatically printed to stderr.
+    // - ctx: the context to retrieve the setting for, or NULL for the global setting.
+    RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
+
+    // Retrieves and clears the error flags.
+    // - ctx: the context to retrieve the error for, or NULL for the global error.
+    RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
+
+    // Loads the model from a file and prepares it for inference.
+    // Returns NULL on any error. Error messages will be printed to stderr.
+    // - model_file_path: path to model file in ggml format.
+    // - n_threads: count of threads to use, must be positive.
+    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
+
+    // Evaluates the model for a single token.
+    // Returns false on any error. Error messages will be printed to stderr.
+    // - token: next token index, in range 0 <= token < n_vocab.
+    // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
+    // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+    // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+    RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
+
+    // Returns count of FP32 elements in state buffer.
+    RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);
+
+    // Returns count of FP32 elements in logits buffer.
+    RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
+
+    // Frees all allocated memory and the context.
+    RWKV_API void rwkv_free(struct rwkv_context * ctx);
+
+    // Quantizes FP32 or FP16 model to one of quantized formats.
+    // Returns false on any error. Error messages will be printed to stderr.
+    // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
+    // - model_file_path_out: quantized model will be written here.
+    // - format_name: must be one of available format names below.
+ // Available format names: + // - Q4_0 + // - Q4_1 + // - Q5_0 + // - Q5_1 + // - Q8_0 + RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name); + + // Returns system information string. + RWKV_API const char * rwkv_get_system_info_string(void); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/otherarch/tools/convert_pt_rwkv.py b/otherarch/tools/convert_pt_rwkv.py index f42c316ce..28da6ecc1 100644 --- a/otherarch/tools/convert_pt_rwkv.py +++ b/otherarch/tools/convert_pt_rwkv.py @@ -1,48 +1,15 @@ -# Converts an RWKV model checkpoint to an rwkv.cpp compatible file. +# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file. # Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32 # Get model checkpoints from https://huggingface.co/BlinkDL +# See FILE_FORMAT.md for the documentation on the file format. -# File format: -# -# RWKVModelFile { -# // All ints and floats are in machine byte order. -# // Magic is "ggml" string bytes. -# int32 magic = 0x67676d66; -# int32 version = 100; -# int32 n_vocab; -# int32 n_embed; -# int32 n_layer; -# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O. -# int32 data_type; -# // Read until EOF. -# Parameter[] parameters; -# } -# -# Parameter { -# int32 dim_count; -# int32 key_length; -# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O. -# int32 data_type; -# // Compared to PyTorch's tensor.shape, dimension order is reversed here! -# int32[dim_count] shape; -# // Keys are like "emb.weight", "block.0.ln1.weight". -# uint8[key_length] key_utf8; -# // float32: 4 * element_count bytes. -# // float16: 2 * element_count bytes. -# // Q4_0: element_count / 32 * 20 bytes. -# // Q4_1: element_count / 32 * 24 bytes. -# // Q4_1_O: element_count / 32 * 24 bytes. 
-# byte[] data; -# } - -import os import argparse import struct import torch from typing import Dict def parse_args(): - parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint to an rwkv.cpp compatible file') + parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file') parser.add_argument('src_path', help='Path to PyTorch checkpoint file') parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten') parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32') @@ -71,8 +38,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t '=iiiiii', # Magic: 'ggmf' in hex 0x67676d66, - # llama.cpp uses file versions 1+, let's use 100+ for rwkv.cpp - 100, + 101, n_vocab, n_embed, n_layer, @@ -129,53 +95,5 @@ def main() -> None: print('Done') -# --- Tests --- - -def test() -> None: - test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp' - - try: - state_dict: Dict[str, torch.Tensor] = { - 'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), - 'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32) - } - - write_state_dict(state_dict, dest_path=test_file_path, data_type='float32') - - with open(test_file_path, 'rb') as input: - actual_bytes: bytes = input.read() - - expected_bytes: bytes = struct.pack( - '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf', - 0x67676d66, - 100, - 3, - 2, - 1, - 0, - # emb.weight - 2, - 10, - 0, - 2, 3, - 'emb.weight'.encode('utf-8'), - 1.0, 2.0, 3.0, - 4.0, 5.0, 6.0, - # blocks.0.ln1.weight - 1, - 19, - 0, - 1, - 'blocks.0.ln1.weight'.encode('utf-8'), - 1.0 - ) - - assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}' - - print('All tests pass') - finally: - if os.path.isfile(test_file_path): - os.remove(test_file_path) - if __name__ == "__main__": - main() + main() \ No newline at end of file
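
For reviewers, a minimal usage sketch of the rwkv_v3 API declared in otherarch/rwkv_v3.h above: load a model, run a few tokens while threading the state through, then free the context. The model path, thread count, and token ids are placeholder assumptions, not part of the patch; a real caller would obtain tokens from its tokenizer.

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "otherarch/rwkv_v3.h"

    int main() {
        // Placeholder path: any file produced by convert_pt_rwkv.py (optionally
        // requantized with rwkv_quantize_model_file) is assumed to work here.
        struct rwkv_context * ctx = rwkv_init_from_file("rwkv-model.bin", 4);
        if (ctx == NULL) {
            return 1;
        }

        // Buffer sizes come from the element-count helpers declared in rwkv_v3.h.
        std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
        std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));

        // Placeholder token ids for illustration only.
        const uint32_t tokens[] = {0, 12, 345};

        const float * state_in = NULL; // NULL on the first pass, as the header documents
        for (uint32_t token : tokens) {
            if (!rwkv_eval(ctx, token, state_in, state.data(), logits.data())) {
                fprintf(stderr, "rwkv_eval failed\n");
                rwkv_free(ctx);
                return 1;
            }
            state_in = state.data(); // feed the updated state back in for the next token
        }

        printf("logits[0] = %f\n", logits[0]);
        rwkv_free(ctx);
        return 0;
    }

Note that when state_in is NULL, rwkv_eval itself zero-initializes the state and sets the per-layer max-exponent slots to -1e30, so no manual state setup is needed before the first token.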