fix: address comments
This commit is contained in:
parent
a8109e352e
commit
184a4c676f
4 changed files with 12 additions and 13 deletions
|
@ -2741,8 +2741,7 @@ class MambaModel(Model):
|
|||
# ref: https://stackoverflow.com/a/17511341/22827863
|
||||
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
|
||||
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
num_hidden_layers = self.find_hparam(["n_layer", "num_hidden_layers"])
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
use_b_dt_norm = False
|
||||
# For falconmamba we do apply RMS norm on B / DT and C layers
|
||||
if self.find_hparam(["model_type"]) in ["falcon_mamba"]:
|
||||
|
@ -2754,13 +2753,13 @@ class MambaModel(Model):
|
|||
self.gguf_writer.add_embedding_length(d_model)
|
||||
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
|
||||
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
|
||||
self.gguf_writer.add_block_count(num_hidden_layers)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_ssm_conv_kernel(d_conv)
|
||||
self.gguf_writer.add_ssm_inner_size(d_inner)
|
||||
self.gguf_writer.add_ssm_state_size(d_state)
|
||||
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_mamba_b_dt_rms(use_b_dt_norm) # For classic Mamba we don't apply rms norm on B / DT layers
|
||||
self.gguf_writer.add_mamba_dt_b_c_rms(use_b_dt_norm) # For classic Mamba we don't apply rms norm on B / DT layers
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
_tok_embd = None
|
||||
|
|
|
@ -130,7 +130,7 @@ class Keys:
|
|||
INNER_SIZE = "{arch}.ssm.inner_size"
|
||||
STATE_SIZE = "{arch}.ssm.state_size"
|
||||
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
|
||||
B_DT_RMS = "{arch}.ssm.b_dt_rms"
|
||||
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
|
@ -1373,7 +1373,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
|
|||
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
|
||||
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
|
||||
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
|
||||
KEY_SSM_B_DT_RMS = Keys.SSM.B_DT_RMS
|
||||
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
|
||||
|
||||
# tokenization
|
||||
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
|
||||
|
|
|
@ -715,8 +715,8 @@ class GGUFWriter:
|
|||
def add_rope_scaling_finetuned(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
|
||||
|
||||
def add_mamba_b_dt_rms(self, value: bool) -> None:
|
||||
self.add_bool(Keys.SSM.B_DT_RMS.format(arch=self.arch), value)
|
||||
def add_mamba_dt_b_c_rms(self, value: bool) -> None:
|
||||
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
|
||||
|
|
|
@ -328,7 +328,7 @@ enum llm_kv {
|
|||
LLM_KV_SSM_CONV_KERNEL,
|
||||
LLM_KV_SSM_STATE_SIZE,
|
||||
LLM_KV_SSM_TIME_STEP_RANK,
|
||||
LLM_KV_SSM_B_DT_RMS,
|
||||
LLM_KV_SSM_DT_B_C_RMS,
|
||||
|
||||
LLM_KV_TOKENIZER_MODEL,
|
||||
LLM_KV_TOKENIZER_PRE,
|
||||
|
@ -2239,7 +2239,7 @@ struct llama_hparams {
|
|||
uint32_t ssm_d_inner = 0;
|
||||
uint32_t ssm_d_state = 0;
|
||||
uint32_t ssm_dt_rank = 0;
|
||||
bool ssm_b_dt_rms = false;
|
||||
bool ssm_dt_b_c_rms = false;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
|
@ -5055,7 +5055,7 @@ static void llm_load_hparams(
|
|||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_B_DT_RMS, hparams.ssm_b_dt_rms, false);
|
||||
ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
|
@ -12166,7 +12166,7 @@ struct llm_build_context {
|
|||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t dt_rank = hparams.ssm_dt_rank;
|
||||
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
|
||||
const bool ssm_b_dt_rms = hparams.ssm_b_dt_rms;
|
||||
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
|
||||
// Use the same RMS norm as the final layer norm
|
||||
const float norm_rms_eps = hparams.f_norm_rms_eps;
|
||||
|
||||
|
@ -12250,7 +12250,7 @@ struct llm_build_context {
|
|||
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
|
||||
|
||||
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
|
||||
if (ssm_b_dt_rms) {
|
||||
if (ssm_dt_b_c_rms) {
|
||||
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
|
||||
B = ggml_rms_norm(ctx0, B, norm_rms_eps);
|
||||
C = ggml_rms_norm(ctx0, C, norm_rms_eps);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue