resolve comments

This commit is contained in:
Sourab Mangrulkar 2024-03-01 11:09:35 +05:30
parent d62ce1c6b4
commit 10aa6e927e
3 changed files with 24 additions and 18 deletions

View file

@ -98,7 +98,7 @@ class Model:
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
@ -220,6 +220,8 @@ class Model:
return NomicBertModel
if model_architecture == "GemmaForCausalLM":
return GemmaModel
if model_architecture == "Starcoder2ForCausalLM":
return StarCoderModel2
return Model
def _is_model_safetensors(self) -> bool:
@ -926,6 +928,10 @@ class StarCoderModel(Model):
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
class StarCoderModel2(Model):
def set_vocab(self):
self._set_vocab_gpt2()
class RefactModel(Model):
def set_gguf_parameters(self):