From 5d054a42f96ec959fc9070ea83e160a8a2740225 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 16 Sep 2024 09:15:15 -0600 Subject: [PATCH] fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams Branch: GraniteLM Signed-off-by: Gabe Goodhart --- src/llama.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 9c6703aad..79df86cf8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5438,7 +5438,6 @@ static void llm_load_hparams( // arch-specific KVs switch (model.arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_GRANITE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5455,20 +5454,13 @@ static void llm_load_hparams( // granite uses a vocab with len 49152 case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; case 36: model.type = e_model::MODEL_8B; break; // granite - case 40: model.type = (hparams.n_vocab == 49152 || hparams.n_vocab == 49156) ? e_model::MODEL_3B : e_model::MODEL_13B; break; + case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } } - // Extra multipliers for Granite architecture - if (model.arch == LLM_ARCH_GRANITE) { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); - } } break; case LLM_ARCH_MINICPM: { @@ -6059,6 +6051,20 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GRANITE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; }