Add `wkv.head_size` key for RWKV so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
parent 8d498c7075
commit 903089b5eb
4 changed files with 36 additions and 17 deletions
@@ -2760,14 +2760,13 @@ class RwkvModel(Model):
         self.gguf_writer.add_context_length(1048576)
         self.gguf_writer.add_embedding_length(hidden_size)
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(0)
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
-        self.gguf_writer.add_feed_forward_length(0) # required by llama.cpp
         self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
-        # temporarlily reuse mamba hparams
-        self.gguf_writer.add_ssm_inner_size(hidden_size)
-        self.gguf_writer.add_ssm_conv_kernel(3)
-        self.gguf_writer.add_ssm_state_size(head_size)
+        self.gguf_writer.add_wkv_head_size(head_size)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+        self.gguf_writer.add_feed_forward_length(0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
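Net effect of the hunk above: an RWKV conversion stops emitting the reused ssm.* keys and records the head size under its own key. A rough sketch of the metadata set_gguf_parameters() now produces, assuming a "rwkv6" architecture string and illustrative hyperparameter values (neither is taken from this diff):

# Hypothetical key/value view of the emitted metadata; the "rwkv6" prefix and
# all concrete numbers below are assumptions for illustration only.
expected_metadata = {
    "rwkv6.context_length":               1048576,
    "rwkv6.embedding_length":             2048,     # hidden_size (assumed)
    "rwkv6.block_count":                  24,       # assumed
    "rwkv6.attention.layer_norm_epsilon": 1e-5,     # assumed
    "rwkv6.rescale_every_n_layers":       6,        # assumed
    "rwkv6.wkv.head_size":                64,       # new key, replaces the ssm.* reuse
    "rwkv6.attention.head_count":         0,        # required by llama.cpp, unused
    "rwkv6.feed_forward_length":          0,        # required by llama.cpp, unused
}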
@@ -133,6 +133,9 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
 
+    class WKV:
+        HEAD_SIZE = "{arch}.wkv.head_size"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         PRE = "tokenizer.ggml.pre"
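The new constant is just a per-architecture key template. A minimal expansion sketch, assuming a "rwkv6" architecture string:

from gguf.constants import Keys

# Expand the new template for a hypothetical "rwkv6" architecture string.
key = Keys.WKV.HEAD_SIZE.format(arch="rwkv6")
print(key)  # rwkv6.wkv.head_size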
@@ -673,6 +673,9 @@ class GGUFWriter:
     def add_rescale_every_n_layers(self, count: int) -> None:
         self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
 
+    def add_wkv_head_size(self, size: int) -> None:
+        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
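A minimal usage sketch of the new writer method, assuming the gguf Python package; the output path, architecture string, and head size are placeholders:

from gguf import GGUFWriter

# Write a metadata-only GGUF file that carries the new key (illustrative values).
writer = GGUFWriter("rwkv6-metadata-only.gguf", arch="rwkv6")
writer.add_wkv_head_size(64)   # stored as uint32 under "rwkv6.wkv.head_size"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()  # no tensors were added, so this only finalizes the file
writer.close()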
@@ -333,6 +333,8 @@ enum llm_kv {
     LLM_KV_SSM_TIME_STEP_RANK,
     LLM_KV_SSM_DT_B_C_RMS,
 
+    LLM_KV_WKV_HEAD_SIZE,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
     LLM_KV_TOKENIZER_LIST,
@@ -433,6 +435,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
     { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
 
+    { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -2291,6 +2295,7 @@ struct llama_hparams {
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
+    uint32_t wkv_head_size = 0;
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
@@ -2355,6 +2360,9 @@ struct llama_hparams {
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
+        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->wkv_head_size != other.wkv_head_size) return true;
+
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
         const float EPSILON = 1e-9f;
@@ -2418,16 +2426,26 @@ struct llama_hparams {
     }
 
     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size
+        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (wkv_head_size != 0) {
+            // for RWKV models
+            return 2 * n_embd;
+        } else {
         // TODO: maybe support other convolution strides than 1
         // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
         return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
     }
+    }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+        if (wkv_head_size != 0) {
+            // corresponds to RWKV's wkv_states size
+            return n_embd * wkv_head_size;
+        } else {
         // corresponds to Mamba's ssm_states size
         return ssm_d_state * ssm_d_inner;
     }
+    }
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
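To put the two formulas above in numbers, a rough arithmetic sketch for a hypothetical RWKV-6 configuration (n_embd = 2048, wkv_head_size = 64; reading the two terms as token-shift state and per-head wkv state follows the comments in the diff):

# Per-sequence, per-layer recurrent state sizes implied by n_embd_k_s() / n_embd_v_s();
# the concrete dimensions are assumptions for illustration.
n_embd        = 2048
wkv_head_size = 64

n_embd_k_s = 2 * n_embd               # token_shift states: one shift per time-mix and channel-mix block
n_embd_v_s = n_embd * wkv_head_size   # wkv states: (n_embd // head_size) heads, each head_size x head_size

print(n_embd_k_s)  # 4096 elements per layer
print(n_embd_v_s)  # 131072 elements per layer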
@@ -5888,12 +5906,8 @@ static void llm_load_hparams(
         case LLM_ARCH_RWKV:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
-
-                // TODO: Re-using mamba keys right now, but RWKV isn't state-space
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
-                ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
             } break;
         default: (void)0;
     }
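A quick way to check that a converted file carries the new key instead of the old ssm.* ones, using the gguf Python package; the file path and the "rwkv6" key prefix are assumptions:

from gguf import GGUFReader

reader = GGUFReader("rwkv6-example.gguf")  # hypothetical path
field = reader.get_field("rwkv6.wkv.head_size")
if field is None:
    print("wkv.head_size missing -- file was converted before this change")
else:
    # scalar fields keep their value in the part referenced by field.data
    print("wkv.head_size =", field.parts[field.data[0]][0])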