diff --git a/llama.cpp b/llama.cpp
index 17cd6cd33..9ec68cbdb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1045,6 +1045,11 @@ static void llama_model_load_internal(

         // LLaMAv2
         // TODO: temporary until GGUF
+        //patch for llama2 gqa
+        if (model.type == e_model::MODEL_65B && hparams.n_mult >= 4096) {
+            fprintf(stderr, "%s: Applying KCPP Patch for 70B model, setting GQA to 8\n", __func__);
+            n_gqa = 8;
+        }
         LLAMA_ASSERT(hparams.n_head % n_gqa == 0);
         hparams.n_head_kv = hparams.n_head / n_gqa;
         if (model.type == e_model::MODEL_65B && n_gqa == 8) {