Add `rescale_every_n_layers` parameter

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Author: Molly Sophia, 2024-08-06 18:53:27 +08:00
parent 0784a0cf26
commit 8d498c7075
4 changed files with 16 additions and 3 deletions
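
RWKV's reference implementation keeps fp16 activations in range by halving the hidden state every `rescale_every` layers (6 in the published configs) and pre-dividing the affected output-projection weights by matching powers of two. This commit threads that hyperparameter from the Hugging Face config through GGUF metadata (`{arch}.rescale_every_n_layers`) into llama.cpp's hparams so the graph build can apply the same halving at inference time.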

convert_hf_to_gguf.py

@@ -2754,6 +2754,7 @@ class RwkvModel(Model):
         head_size = self.hparams["head_size"]
         hidden_size = self.hparams["hidden_size"]
         layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        rescale_every_n_layers = self.hparams["rescale_every"]
 
         # RWKV isn't context limited
         self.gguf_writer.add_context_length(1048576)
@@ -2762,14 +2763,13 @@ class RwkvModel(Model):
         self.gguf_writer.add_head_count(0)
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
         self.gguf_writer.add_feed_forward_length(0) # required by llama.cpp
+        self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
         # temporarily reuse mamba hparams
         self.gguf_writer.add_ssm_inner_size(hidden_size)
         self.gguf_writer.add_ssm_conv_kernel(3)
         self.gguf_writer.add_ssm_state_size(head_size)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
         new_name = self.map_tensor_name(name)
 
         if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
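
For reference, a minimal sketch of where the new value comes from on the conversion side; the checkpoint path is a placeholder, but `rescale_every` is a real field of Hugging Face RWKV configs (default 6, with 0 meaning rescaling disabled):

    import json
    from pathlib import Path

    # Placeholder path to an RWKV Hugging Face checkpoint directory.
    config = json.loads(Path("path/to/rwkv-checkpoint/config.json").read_text())

    # The value that RwkvModel.set_gguf_parameters() now forwards to the GGUF writer.
    rescale_every_n_layers = config["rescale_every"]  # typically 6; 0 disables rescaling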

gguf-py/gguf/constants.py

@@ -94,6 +94,7 @@ class Keys:
         DECODER_START_TOKEN_ID  = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING  = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+        RESCALE_EVERY_N_LAYERS  = "{arch}.rescale_every_n_layers"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
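
The `{arch}` placeholder is substituted when the key is written; a quick sketch (the arch string is assumed here, matching whatever this branch registers for RWKV):

    from gguf.constants import Keys

    # Resolves to the key string that llama.cpp looks up below.
    print(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch="rwkv"))  # rwkv.rescale_every_n_layers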

gguf-py/gguf/gguf_writer.py

@@ -670,6 +670,9 @@ class GGUFWriter:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_rescale_every_n_layers(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
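
A minimal write-side sketch using the new helper, mirroring the call RwkvModel.set_gguf_parameters() makes (output path and arch string are placeholders):

    from gguf import GGUFWriter

    writer = GGUFWriter("model.gguf", arch="rwkv")   # placeholders
    writer.add_rescale_every_n_layers(6)             # stored as a uint32 under "rwkv.rescale_every_n_layers"

The field then travels with the rest of the KV metadata when the writer flushes the header, KV data, and tensors.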

src/llama.cpp

@@ -297,6 +297,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+    LLM_KV_RESCALE_EVERY_N_LAYERS,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -391,11 +392,12 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,        "%s.expert_used_count"        },
     { LLM_KV_EXPERT_SHARED_COUNT,      "%s.expert_shared_count"      },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,     "%s.expert_weights_scale"     },
-    { LLM_KV_POOLING_TYPE ,            "%s.pooling_type"             },
+    { LLM_KV_POOLING_TYPE,             "%s.pooling_type"             },
     { LLM_KV_LOGIT_SCALE,              "%s.logit_scale"              },
     { LLM_KV_DECODER_START_TOKEN_ID,   "%s.decoder_start_token_id"   },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,   "%s.attn_logit_softcapping"   },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,  "%s.final_logit_softcapping"  },
+    { LLM_KV_RESCALE_EVERY_N_LAYERS,   "%s.rescale_every_n_layers"   },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,     "%s.attention.head_count"     },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,  "%s.attention.head_count_kv"  },
@@ -2287,6 +2289,9 @@ struct llama_hparams {
     float f_attn_logit_softcapping  = 50.0f;
     float f_final_logit_softcapping = 30.0f;
 
+    // for RWKV
+    uint32_t rescale_every_n_layers = 0;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -5883,6 +5888,7 @@ static void llm_load_hparams(
         case LLM_ARCH_RWKV:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
 
                 // TODO: Re-using mamba keys right now, but RWKV isn't state-space
                 ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
@@ -15130,6 +15136,9 @@ struct llm_build_context {
                     ffn_shift
                 )
             );
+            if (hparams.rescale_every_n_layers > 0 && (layer_i + 1) % hparams.rescale_every_n_layers == 0) {
+                x = ggml_scale_inplace(ctx0, x, 0.5F);
+            }
         }
 
         // Something related to skipping tokens, specifics unclear
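
Note on the guard in the last hunk: the key is read with the optional flag (`false`) and `rescale_every_n_layers` defaults to 0, so a bare modulo would divide by zero for models that omit the field. A small Python mirror of the resulting schedule (illustrative only, not llama.cpp code):

    def rescale_layers(n_layer: int, rescale_every_n_layers: int) -> list[int]:
        """Return the 0-based layer indices after which activations are halved."""
        if rescale_every_n_layers <= 0:  # key absent or rescaling disabled
            return []
        return [il for il in range(n_layer) if (il + 1) % rescale_every_n_layers == 0]

    # RWKV's usual rescale_every = 6 with a 24-layer model:
    print(rescale_layers(24, 6))  # [5, 11, 17, 23]; scaled by 0.5 after each, matching ggml_scale_inplace(..., 0.5F)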