convert-hf : support new metadata keys for Mamba

For the models available at
https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406
Author: Francis Couture-Harpin
Date:   2024-03-07 20:28:42 -05:00
parent 7cd5a1f986
commit d8024a486b


@@ -1884,14 +1884,15 @@ class MambaModel(Model):
             self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
 
     def set_gguf_parameters(self):
-        d_model = self.hparams["d_model"]
-        d_conv = self.hparams.get("d_conv", 4)
-        d_inner = self.hparams.get("d_inner", 2 * d_model)
-        d_state = self.hparams.get("d_state", 16)
+        d_model = self.find_hparam(["hidden_size", "d_model"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
         # ceiling division
         # ref: https://stackoverflow.com/a/17511341/22827863
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
-        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
@@ -1906,7 +1907,7 @@ class MambaModel(Model):
         self.gguf_writer.add_ssm_inner_length(d_inner)
         self.gguf_writer.add_ssm_state_length(d_state)
         self.gguf_writer.add_ssm_dt_rank(dt_rank)
-        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
         self.gguf_writer.add_file_type(self.ftype)
 
     def write_tensors(self):
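
The rewritten lookups rely on find_hparam(), a helper on the Model base class in convert-hf-to-gguf.py. A minimal sketch of the assumed behavior, for illustration only (return the value of the first listed key present in the config; with optional=True, return None instead of raising):

    from typing import Any, Optional

    def find_hparam(hparams: dict, keys: list, optional: bool = False) -> Optional[Any]:
        # Return the value of the first listed key present in the config dict.
        key = next((k for k in keys if k in hparams), None)
        if key is not None:
            return hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")

With optional=True a missing key yields None, so the `find_hparam(...) or default` pattern falls back to the hard-coded default. Note that `or` would also replace a stored falsy value (0 or 0.0); that is harmless here because all of these hyperparameters are strictly positive.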
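
The `-(d_model // -16)` fallback for dt_rank is ceiling division without floating point, as the linked Stack Overflow answer explains: floor division by a negated divisor rounds toward negative infinity, so negating the quotient yields the ceiling. A quick check (the d_model values are illustrative):

    import math

    for d_model in (768, 1026, 2560):
        # Integer ceiling division matches the float-based ceiling.
        assert -(d_model // -16) == math.ceil(d_model / 16)
    # e.g. d_model = 768 -> dt_rank = 48, and 1026 -> 65 (not 64).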