From 9f28f73785ecf62fe2625809b09348f8c2dd7625 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 22 Aug 2023 20:34:17 +0300
Subject: [PATCH] llm : read arch-specific KVs

---
 llama.cpp | 41 +++++++++++++++++++++++++++++++++--------
 1 file changed, 33 insertions(+), 8 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index b2c9b3396..5c0bf6190 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -344,6 +344,7 @@ struct LLM_TN {
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
 };
+
 //
 // gguf helpers
 //
@@ -1497,12 +1498,12 @@ static void llm_load_hparams(
         int n_ctx,
         float rope_freq_base,
         float rope_freq_scale) {
-    auto & hparams = model.hparams;
-
     struct gguf_context * ctx = ml.ctx_gguf;
 
     const auto kv = LLM_KV(arch);
 
+    auto & hparams = model.hparams;
+
     // get general kv
     GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
     GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
@@ -1514,8 +1515,6 @@ static void llm_load_hparams(
     GGUF_GET_KEY(ctx, hparams.n_ff,    gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
     GGUF_GET_KEY(ctx, hparams.n_head,  gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
-    GGUF_GET_KEY(ctx, hparams.n_rot,   gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
-    GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
@@ -1523,12 +1522,37 @@ static void llm_load_hparams(
 
     // TODO: manually setting rope scale should override this
     // rope_freq_scale (inverse of the kv) is optional
-    float ropescale = 1.0f;
-    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-    if (ropescale != 1.0f) {
-        rope_freq_scale = 1.0f/ropescale;
+    {
+        float ropescale = 1.0f;
+        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+        if (ropescale != 1.0f) {
+            rope_freq_scale = 1.0f/ropescale;
+        }
     }
 
+    // sanity check for n_rot (optional)
+    {
+        hparams.n_rot = hparams.n_embd / hparams.n_head;
+
+        GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+
+        if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
+            throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
+        }
+    }
+
+    // arch-specific KVs
+    switch (arch) {
+        case LLM_ARCH_LLAMA:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+            } break;
+        default: (void)0;
+    };
+
     // TODO: generalize to non-LLaMA models
     switch (hparams.n_layer) {
         case 26: model.type = e_model::MODEL_3B; break;
@@ -1594,6 +1618,7 @@ static void llm_load_vocab(
     // determine vocab type
     {
         std::string tokenizer_name;
+
        GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));

        if (tokenizer_name == "llama") {
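
Side note (not part of the patch): below is a minimal, self-contained sketch of the same idea, reading architecture-prefixed GGUF KVs plus one arch-specific key, but using the raw gguf C API declared in ggml.h directly instead of the GGUF_GET_KEY macro. The key strings follow the GGUF naming convention used by llama.cpp ("<arch>.<key>", "general.architecture"); the get_u32_or helper, the chosen defaults, and the simplified error handling are assumptions made only for illustration.

// sketch: read GGUF hyperparameters by hand, mirroring the new
// arch-specific branch in llm_load_hparams() (assumptions noted above)
#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <string>

// read an optional u32 KV, falling back to a default when the key is missing
static uint32_t get_u32_or(const gguf_context * ctx, const std::string & key, uint32_t def) {
    const int kid = gguf_find_key(ctx, key.c_str());
    return kid < 0 ? def : gguf_get_val_u32(ctx, kid);
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    // load only the metadata (KVs and tensor info), no tensor data
    gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to load %s\n", argv[1]);
        return 1;
    }

    // general KV: which architecture prefixes the remaining keys
    const int arch_kid = gguf_find_key(ctx, "general.architecture");
    const std::string arch = arch_kid < 0 ? "llama" : gguf_get_val_str(ctx, arch_kid);

    // shared hparams, addressed via the "<arch>." prefix
    const uint32_t n_embd = get_u32_or(ctx, arch + ".embedding_length",     0);
    const uint32_t n_head = get_u32_or(ctx, arch + ".attention.head_count", 1);

    // same n_embd/n_head default for n_rot as in the patch (the mismatch check is omitted here)
    const uint32_t n_rot = get_u32_or(ctx, arch + ".rope.dimension_count", n_embd / n_head);

    // arch-specific KV, analogous to the LLM_ARCH_LLAMA case in the patch
    float f_norm_rms_eps = 0.0f;
    if (arch == "llama") {
        const int kid = gguf_find_key(ctx, (arch + ".attention.layer_norm_rms_epsilon").c_str());
        if (kid >= 0) {
            f_norm_rms_eps = gguf_get_val_f32(ctx, kid);
        }
    }

    printf("arch = %s, n_embd = %u, n_head = %u, n_rot = %u, rms_eps = %g\n",
            arch.c_str(), n_embd, n_head, n_rot, f_norm_rms_eps);

    gguf_free(ctx);
    return 0;
}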