From c2008b568f7697badd833708c9ca8805b46ec4e7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Jan 2025 16:44:49 +0200
Subject: [PATCH] hparams : remove n_vocab

---
 src/llama-context.cpp |  9 +++++----
 src/llama-hparams.h   |  1 -
 src/llama-model.cpp   | 12 ++++++------
 src/llama.cpp         |  8 ++------
 4 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4b195eaca..e20482516 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -469,11 +469,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
 size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto & cparams = lctx.cparams;
     const auto & hparams = lctx.model.hparams;
+    const auto & vocab   = lctx.model.vocab;
 
     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
 
     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = hparams.n_vocab;
+    const auto n_vocab = vocab.n_vocab();
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
@@ -540,7 +541,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
 void llama_output_reorder(struct llama_context & ctx) {
     std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
     if (!out_ids.empty()) {
-        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+        const uint32_t n_vocab = ctx.model.vocab.n_vocab();
         const uint32_t n_embd  = ctx.model.hparams.n_embd;
 
         const int32_t n_outputs = ctx.n_outputs;
@@ -724,7 +725,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
         }
 
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
+        return ctx->logits + j*ctx->model.vocab.n_vocab();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -884,7 +885,7 @@ struct llama_data_write {
     }
 
     void write_logits(const struct llama_context * ctx) {
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_vocab());
 
         write(&logits_size, sizeof(logits_size));
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index a29f20ec4..60362c5d9 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -30,7 +30,6 @@ struct llama_hparams {
     bool use_par_res;
     bool swin_norm;
 
-    uint32_t n_vocab = 0;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0cc6e09f0..1d06af60c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -402,9 +402,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
-    // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
-
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
         return;
     }
@@ -500,6 +497,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         hparams.n_embd_head_v = 0;
     }
 
+    uint32_t n_vocab = 0;
+
+    ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -519,7 +520,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                         case 26: type = LLM_TYPE_3B; break;
                         case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B
                         // granite uses a vocab with len 49152
-                        case 32: type = hparams.n_vocab == 49152 ? LLM_TYPE_3B : (hparams.n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
+                        case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
                         case 36: type = LLM_TYPE_8B; break; // granite
                         case 40: type = LLM_TYPE_13B; break;
                         case 48: type = LLM_TYPE_34B; break;
@@ -1365,7 +1366,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
     const int64_t n_ff          = hparams.n_ff();
     const int64_t n_embd_gqa    = n_embd_v_gqa;
-    const int64_t n_vocab       = hparams.n_vocab;
+    const int64_t n_vocab       = vocab.n_vocab();
     const int64_t n_vocab_type  = hparams.n_vocab_type;
     const int64_t n_rot         = hparams.n_rot;
     const int64_t n_expert      = hparams.n_expert;
@@ -3494,7 +3495,6 @@ void llama_model::print_info() const {
 
     // hparams
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, arch_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab (hp)     = %u\n",     __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
diff --git a/src/llama.cpp b/src/llama.cpp
index 0ea250ad0..b98499f79 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -65,11 +65,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
     model.load_stats(ml);
     model.print_info();
 
-    if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
-        model.hparams.n_vocab != model.vocab.n_vocab()) {
-        throw std::runtime_error("vocab size mismatch");
-    }
-
     if (params.vocab_only) {
         LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
         return 0;
@@ -8342,6 +8337,7 @@ static int llama_decode_impl(
     const uint32_t n_tokens_all = batch.n_tokens;
 
     const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
@@ -8369,7 +8365,7 @@ static int llama_decode_impl(
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
+    const int64_t n_vocab = vocab.n_vocab();
 
     uint32_t n_outputs      = 0;
     uint32_t n_outputs_prev = 0;