From 0a8ad8f1d89070729fd65b5e7a3ae1437e165c40 Mon Sep 17 00:00:00 2001 From: randxie Date: Fri, 28 Jul 2023 00:21:58 -0700 Subject: [PATCH] use n_embd_gqa instead of n_embd to handle llama-2 70B --- examples/save-load-state/save-load-state.cpp | 1 + llama.cpp | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 4c8688503..61c71c358 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -26,6 +26,7 @@ int main(int argc, char ** argv) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; + lparams.n_gqa = params.n_gqa; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; diff --git a/llama.cpp b/llama.cpp index 9a8ecdcf6..a4489773f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3663,7 +3663,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd; + const int n_embd = hparams.n_embd_gqa(); const int n_ctx = hparams.n_ctx; const size_t kv_size = kv_self.buf.size; @@ -3766,7 +3766,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd; + const int n_embd = hparams.n_embd_gqa(); const int n_ctx = hparams.n_ctx; size_t kv_size;