diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 65165b764..a9cfa9ffb 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2760,14 +2760,13 @@ class RwkvModel(Model):
         self.gguf_writer.add_context_length(1048576)
         self.gguf_writer.add_embedding_length(hidden_size)
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(0)
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
-        self.gguf_writer.add_feed_forward_length(0) # required by llama.cpp
         self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
-        # temporarlily reuse mamba hparams
-        self.gguf_writer.add_ssm_inner_size(hidden_size)
-        self.gguf_writer.add_ssm_conv_kernel(3)
-        self.gguf_writer.add_ssm_state_size(head_size)
+        self.gguf_writer.add_wkv_head_size(head_size)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+        self.gguf_writer.add_feed_forward_length(0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a6883b392..32b902480 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -133,6 +133,9 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
 
+    class WKV:
+        HEAD_SIZE = "{arch}.wkv.head_size"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         PRE = "tokenizer.ggml.pre"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 6bc3782c3..0388db567 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -673,6 +673,9 @@ class GGUFWriter:
     def add_rescale_every_n_layers(self, count: int) -> None:
         self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
 
+    def add_wkv_head_size(self, size: int) -> None:
+        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
diff --git a/src/llama.cpp b/src/llama.cpp
index bfc292f59..c755e728f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -333,6 +333,8 @@ enum llm_kv {
     LLM_KV_SSM_TIME_STEP_RANK,
     LLM_KV_SSM_DT_B_C_RMS,
 
+    LLM_KV_WKV_HEAD_SIZE,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
     LLM_KV_TOKENIZER_LIST,
@@ -433,6 +435,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
     { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
 
+    { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -2291,6 +2295,7 @@ struct llama_hparams {
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
+    uint32_t wkv_head_size = 0;
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
@@ -2355,6 +2360,9 @@ struct llama_hparams {
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
+        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->wkv_head_size != other.wkv_head_size) return true;
+
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
         const float EPSILON = 1e-9f;
@@ -2418,15 +2426,25 @@ struct llama_hparams {
     }
 
     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size
-        // TODO: maybe support other convolution strides than 1
-        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (wkv_head_size != 0) {
+            // for RWKV models
+            return 2 * n_embd;
+        } else {
+            // TODO: maybe support other convolution strides than 1
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        }
     }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        // corresponds to Mamba's ssm_states size
-        return ssm_d_state * ssm_d_inner;
+        if (wkv_head_size != 0) {
+            // corresponds to RWKV's wkv_states size
+            return n_embd * wkv_head_size;
+        } else {
+            // corresponds to Mamba's ssm_states size
+            return ssm_d_state * ssm_d_inner;
+        }
     }
 };
 
@@ -5888,12 +5906,8 @@ static void llm_load_hparams(
         case LLM_ARCH_RWKV:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
-
-                // TODO: Re-using mamba keys right now, but RWKV isn't state-space
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
-                ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
             } break;
         default: (void)0;
     }
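For reference, a minimal Python sketch (not part of the patch) of the per-layer recurrent state sizes implied by the new n_embd_k_s() / n_embd_v_s() logic above; the concrete n_embd and head_size values are illustrative assumptions, not taken from any specific model:

    # Mirrors the state-size selection added to llama_hparams above.
    # n_embd and wkv_head_size below are example values only.

    def n_embd_k_s(n_embd: int, wkv_head_size: int, ssm_d_conv: int = 0, ssm_d_inner: int = 0) -> int:
        # RWKV: two token_shift rows per layer; otherwise Mamba's conv_states size
        if wkv_head_size != 0:
            return 2 * n_embd
        return (ssm_d_conv - 1 if ssm_d_conv > 0 else 0) * ssm_d_inner

    def n_embd_v_s(n_embd: int, wkv_head_size: int, ssm_d_state: int = 0, ssm_d_inner: int = 0) -> int:
        # RWKV: wkv_states size; otherwise Mamba's ssm_states size
        if wkv_head_size != 0:
            return n_embd * wkv_head_size
        return ssm_d_state * ssm_d_inner

    n_embd, head_size = 2048, 64                # assumed example hyperparameters
    print(n_embd_k_s(n_embd, head_size))        # 4096 token_shift elements per layer
    print(n_embd_v_s(n_embd, head_size))        # 131072 wkv elements per layer

With wkv_head_size set, the token_shift state takes the slot the recurrent cache previously sized for Mamba's conv state and the wkv state takes the ssm-state slot, so on the model side only the new {arch}.wkv.head_size key is required.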