convert : optionally use d_conv and d_state from config.json for Mamba

This commit is contained in:
Francis Couture-Harpin 2024-01-29 08:27:09 -05:00
parent 54d3e48601
commit 74eea856bf

View file

@@ -1850,15 +1850,19 @@ class MambaModel(Model):
def set_gguf_parameters(self): def set_gguf_parameters(self):
d_model = self.hparams["d_model"] d_model = self.hparams["d_model"]
d_inner = self.hparams.get("d_inner", 2 * d_model)
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model
self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(128) # arbitrary value; it shouldn't be important for Mamba self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(2 * d_model) # d_inner self.gguf_writer.add_head_count(d_inner)
self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_layer_norm_rms_eps(1e-5) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_key_length(4) # d_conv self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4))
self.gguf_writer.add_value_length(16) # d_state self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_file_type(self.ftype)
def write_tensors(self): def write_tensors(self):