fix: add more cleanup and harmonization
This commit is contained in:
parent
349426546b
commit
f7d2e9105f
1 changed files with 3 additions and 3 deletions
|
@ -2742,10 +2742,10 @@ class MambaModel(Model):
|
||||||
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
|
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
|
||||||
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
|
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
|
||||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||||
use_b_dt_norm = False
|
use_dt_b_c_norm = False
|
||||||
# For falconmamba we do apply RMS norm on B / DT and C layers
|
# For falconmamba we do apply RMS norm on B / DT and C layers
|
||||||
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
|
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
|
||||||
use_b_dt_norm = True
|
use_dt_b_c_norm = True
|
||||||
# Fail early for models which don't have a block expansion factor of 2
|
# Fail early for models which don't have a block expansion factor of 2
|
||||||
assert d_inner == 2 * d_model
|
assert d_inner == 2 * d_model
|
||||||
|
|
||||||
|
@ -2759,7 +2759,7 @@ class MambaModel(Model):
|
||||||
self.gguf_writer.add_ssm_state_size(d_state)
|
self.gguf_writer.add_ssm_state_size(d_state)
|
||||||
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
|
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||||
self.gguf_writer.add_mamba_dt_b_c_rms(use_b_dt_norm) # For classic Mamba we don't apply rms norm on B / DT layers
|
self.gguf_writer.add_mamba_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
_tok_embd = None
|
_tok_embd = None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue