From f2985a070b71744c0221119afd401cd7818da9ad Mon Sep 17 00:00:00 2001 From: vxiiduu <73044267+vxiiduu@users.noreply.github.com> Date: Fri, 1 Sep 2023 01:29:09 +1000 Subject: [PATCH] Add support for 34B GGML models --- otherarch/llama_v3.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/otherarch/llama_v3.cpp b/otherarch/llama_v3.cpp index 47bcc4879..f0c919864 100644 --- a/otherarch/llama_v3.cpp +++ b/otherarch/llama_v3.cpp @@ -80,6 +80,7 @@ enum e_model3 { MODEL_7B_3, MODEL_13B_3, MODEL_30B_3, + MODEL_34B_3, MODEL_65B_3, MODEL_70B_3, }; @@ -124,6 +125,7 @@ static std::map MEM_REQ_SCRATCH0_3(int n_ctx) { MODEL_7B_3, ((size_t) n_ctx / 16ull + 164ull) * MB3 }, { MODEL_13B_3, ((size_t) n_ctx / 12ull + 184ull) * MB3 }, { MODEL_30B_3, ((size_t) n_ctx / 9ull + 224ull) * MB3 }, + { MODEL_34B_3, ((size_t) n_ctx / 8ull + 250ull) * MB3 }, // guess { MODEL_65B_3, ((size_t) n_ctx / 6ull + 320ull) * MB3 }, // guess { MODEL_70B_3, ((size_t) n_ctx / 7ull + 320ull) * MB3 }, }; @@ -137,6 +139,7 @@ static const std::map & MEM_REQ_SCRATCH1_3() { MODEL_7B_3, 224ull * MB3 }, { MODEL_13B_3, 256ull * MB3 }, { MODEL_30B_3, 320ull * MB3 }, + { MODEL_34B_3, 384ull * MB3 }, // guess { MODEL_65B_3, 448ull * MB3 }, // guess { MODEL_70B_3, 448ull * MB3 }, }; @@ -151,6 +154,7 @@ static const std::map & MEM_REQ_EVAL_3() { MODEL_7B_3, 20ull * MB3 }, { MODEL_13B_3, 24ull * MB3 }, { MODEL_30B_3, 32ull * MB3 }, + { MODEL_34B_3, 38ull * MB3 }, // guess { MODEL_65B_3, 48ull * MB3 }, // guess { MODEL_70B_3, 48ull * MB3 }, }; @@ -166,6 +170,7 @@ static const std::map & VRAM_REQ_SCRATCH_BASE_3() { MODEL_7B_3, 512ull * kB3 }, { MODEL_13B_3, 640ull * kB3 }, { MODEL_30B_3, 768ull * kB3 }, + { MODEL_34B_3, 960ull * kB3 }, { MODEL_65B_3, 1360ull * kB3 }, { MODEL_70B_3, 1360ull * kB3 }, }; @@ -181,6 +186,7 @@ static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT_3() { MODEL_7B_3, 128ull }, { MODEL_13B_3, 160ull }, { MODEL_30B_3, 208ull }, + { MODEL_34B_3, 356ull }, { MODEL_65B_3, 320ull }, { 
MODEL_70B_3, 320ull }, }; @@ -1034,6 +1040,7 @@ static const char * llama_v3_model_type_name(e_model3 type) { case MODEL_7B_3: return "7B"; case MODEL_13B_3: return "13B"; case MODEL_30B_3: return "30B"; + case MODEL_34B_3: return "34B"; case MODEL_65B_3: return "65B"; case MODEL_70B_3: return "70B"; default: LLAMA_V3_ASSERT(false); @@ -1082,6 +1089,7 @@ static void llama_v3_model_load_internal( case 26: model.type = e_model3::MODEL_3B_3; break; case 32: model.type = e_model3::MODEL_7B_3; break; case 40: model.type = e_model3::MODEL_13B_3; break; + case 48: model.type = e_model3::MODEL_34B_3; break; case 60: model.type = e_model3::MODEL_30B_3; break; case 80: model.type = e_model3::MODEL_65B_3; break; default: @@ -1101,6 +1109,11 @@ static void llama_v3_model_load_internal( fprintf(stderr, "%s: Applying KCPP Patch for 70B model, setting GQA to 8\n", __func__); n_gqa = 8; } + + if (model.type == e_model3::MODEL_34B_3) { + fprintf(stderr, "%s: Applying KCPP Patch for 34B model, setting GQA to 8\n", __func__); + n_gqa = 8; + } LLAMA_V3_ASSERT(hparams.n_head % n_gqa == 0); hparams.n_head_kv = hparams.n_head / n_gqa; if (model.type == e_model3::MODEL_65B_3 && n_gqa == 8) {