Add `wkv.head_size` key for RWKV

so that RWKV no longer reuses the Mamba SSM hparam keys

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2024-08-07 10:35:40 +08:00
parent 8d498c7075
commit 903089b5eb
4 changed files with 36 additions and 17 deletions


@@ -2760,14 +2760,13 @@ class RwkvModel(Model):
         self.gguf_writer.add_context_length(1048576)
         self.gguf_writer.add_embedding_length(hidden_size)
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(0)
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
-        self.gguf_writer.add_feed_forward_length(0) # required by llama.cpp
         self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
-        # temporarlily reuse mamba hparams
-        self.gguf_writer.add_ssm_inner_size(hidden_size)
-        self.gguf_writer.add_ssm_conv_kernel(3) # required by llama.cpp, unused
-        self.gguf_writer.add_ssm_state_size(head_size)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_head_count(0)
+        self.gguf_writer.add_feed_forward_length(0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)


@@ -133,6 +133,9 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
 
+    class WKV:
+        HEAD_SIZE = "{arch}.wkv.head_size"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         PRE = "tokenizer.ggml.pre"


@@ -673,6 +673,9 @@ class GGUFWriter:
     def add_rescale_every_n_layers(self, count: int) -> None:
         self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
 
+    def add_wkv_head_size(self, size: int) -> None:
+        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
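
Not part of the diff, but the new key can be exercised end to end with gguf-py. The sketch below is illustrative only: it assumes the GGUFWriter/GGUFReader APIs of this period, and the "rwkv" architecture string, the file name, and the head size of 64 are made-up values.

# write_and_read_wkv_head_size.py -- illustrative sketch, not part of the commit
import gguf

# write a metadata-only GGUF using the new writer method
# ("rwkv" arch string, file name and head size 64 are assumptions)
writer = gguf.GGUFWriter("rwkv-meta.gguf", "rwkv")
writer.add_wkv_head_size(64)        # emits the new "{arch}.wkv.head_size" key
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()      # no tensors in this sketch
writer.close()

# read the key back
reader = gguf.GGUFReader("rwkv-meta.gguf")
field = reader.fields["rwkv.wkv.head_size"]
print(field.parts[-1][0])           # expected: 64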


@@ -333,6 +333,8 @@ enum llm_kv {
     LLM_KV_SSM_TIME_STEP_RANK,
     LLM_KV_SSM_DT_B_C_RMS,
 
+    LLM_KV_WKV_HEAD_SIZE,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
     LLM_KV_TOKENIZER_LIST,
@@ -433,6 +435,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
     { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
 
+    { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -2291,6 +2295,7 @@ struct llama_hparams {
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
+    uint32_t wkv_head_size = 0;
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
@@ -2355,6 +2360,9 @@ struct llama_hparams {
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
+        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->wkv_head_size != other.wkv_head_size) return true;
+
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
         const float EPSILON = 1e-9f;
@@ -2418,15 +2426,25 @@ struct llama_hparams {
     }
 
     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size
-        // TODO: maybe support other convolution strides than 1
-        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (wkv_head_size != 0) {
+            // for RWKV models
+            return 2 * n_embd;
+        } else {
+            // TODO: maybe support other convolution strides than 1
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        }
     }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        // corresponds to Mamba's ssm_states size
-        return ssm_d_state * ssm_d_inner;
+        if (wkv_head_size != 0) {
+            // corresponds to RWKV's wkv_states size
+            return n_embd * wkv_head_size;
+        } else {
+            // corresponds to Mamba's ssm_states size
+            return ssm_d_state * ssm_d_inner;
+        }
     }
 };
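
For a sense of the per-layer state sizes these accessors now imply, here is a small Python mirror of the RWKV branch; the concrete numbers (n_embd = 2048, head_size = 64) are assumed RWKV v6 style values for illustration, not taken from the commit.

# illustrative mirror of the new per-layer recurrent state sizing (assumed values)
n_embd        = 2048   # model hidden size
wkv_head_size = 64     # value stored under the new "{arch}.wkv.head_size" key

assert wkv_head_size != 0  # RWKV branch of n_embd_k_s() / n_embd_v_s()

# token_shift state: previous hidden state for the time-mix and channel-mix branches
n_embd_k_s = 2 * n_embd                 # 4096 elements per layer
# wkv state: (n_embd / head_size) heads, each holding a head_size x head_size matrix
n_embd_v_s = n_embd * wkv_head_size     # 131072 elements per layer

print(n_embd_k_s, n_embd_v_s)           # 4096 131072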
@@ -5888,12 +5906,8 @@ static void llm_load_hparams(
         case LLM_ARCH_RWKV:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
-
-                // TODO: Re-using mamba keys right now, but RWKV isn't state-space
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
-                ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
             } break;
         default: (void)0;
     }