convert-hf : support new metadata keys for Mamba
For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406
This commit is contained in:
parent
7cd5a1f986
commit
d8024a486b
1 changed files with 7 additions and 6 deletions
|
@ -1884,14 +1884,15 @@ class MambaModel(Model):
|
||||||
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
|
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
d_model = self.hparams["d_model"]
|
d_model = self.find_hparam(["hidden_size", "d_model"])
|
||||||
d_conv = self.hparams.get("d_conv", 4)
|
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
|
||||||
d_inner = self.hparams.get("d_inner", 2 * d_model)
|
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
|
||||||
d_state = self.hparams.get("d_state", 16)
|
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
|
||||||
# ceiling division
|
# ceiling division
|
||||||
# ref: https://stackoverflow.com/a/17511341/22827863
|
# ref: https://stackoverflow.com/a/17511341/22827863
|
||||||
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
|
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
|
||||||
dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
|
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
|
||||||
|
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||||
|
|
||||||
# Fail early for models which don't have a block expansion factor of 2
|
# Fail early for models which don't have a block expansion factor of 2
|
||||||
assert d_inner == 2 * d_model
|
assert d_inner == 2 * d_model
|
||||||
|
@ -1906,7 +1907,7 @@ class MambaModel(Model):
|
||||||
self.gguf_writer.add_ssm_inner_length(d_inner)
|
self.gguf_writer.add_ssm_inner_length(d_inner)
|
||||||
self.gguf_writer.add_ssm_state_length(d_state)
|
self.gguf_writer.add_ssm_state_length(d_state)
|
||||||
self.gguf_writer.add_ssm_dt_rank(dt_rank)
|
self.gguf_writer.add_ssm_dt_rank(dt_rank)
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
|
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue