Add support for 34B GGML models

vxiiduu authored 2023-09-01 01:29:09 +10:00, committed by GitHub
parent b6914ebd04
commit f2985a070b


@@ -80,6 +80,7 @@ enum e_model3 {
     MODEL_7B_3,
     MODEL_13B_3,
     MODEL_30B_3,
+    MODEL_34B_3,
     MODEL_65B_3,
     MODEL_70B_3,
 };
@@ -124,6 +125,7 @@ static std::map<e_model3, size_t> MEM_REQ_SCRATCH0_3(int n_ctx)
         { MODEL_7B_3,  ((size_t) n_ctx / 16ull + 164ull) * MB3 },
         { MODEL_13B_3, ((size_t) n_ctx / 12ull + 184ull) * MB3 },
         { MODEL_30B_3, ((size_t) n_ctx /  9ull + 224ull) * MB3 },
+        { MODEL_34B_3, ((size_t) n_ctx /  8ull + 250ull) * MB3 }, // guess
         { MODEL_65B_3, ((size_t) n_ctx /  6ull + 320ull) * MB3 }, // guess
         { MODEL_70B_3, ((size_t) n_ctx /  7ull + 320ull) * MB3 },
     };
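
For reference, the new MEM_REQ_SCRATCH0_3 entry is explicitly marked as a guess and slots between the 30B and 65B formulas. Assuming MB3 is the usual 1 MiB constant (as in upstream llama.cpp), a standalone sketch of what the guessed 34B expression evaluates to for an illustrative context length:

// Standalone sketch: n_ctx is illustrative, MB3 definition is assumed.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t MB3   = 1024ull * 1024ull;  // assumed to match upstream llama.cpp
    const size_t n_ctx = 2048;               // example context length
    const size_t scratch0_34b = ((size_t) n_ctx / 8ull + 250ull) * MB3;  // 256 + 250 = 506 MiB
    printf("34B scratch0 at n_ctx=%zu: %zu MiB\n", n_ctx, scratch0_34b / MB3);
}
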
@@ -137,6 +139,7 @@ static const std::map<e_model3, size_t> & MEM_REQ_SCRATCH1_3()
         { MODEL_7B_3,  224ull * MB3 },
         { MODEL_13B_3, 256ull * MB3 },
         { MODEL_30B_3, 320ull * MB3 },
+        { MODEL_34B_3, 38ull * MB3 }, // guess
         { MODEL_65B_3, 448ull * MB3 }, // guess
         { MODEL_70B_3, 448ull * MB3 },
     };
@@ -151,6 +154,7 @@ static const std::map<e_model3, size_t> & MEM_REQ_EVAL_3()
         { MODEL_7B_3,  20ull * MB3 },
         { MODEL_13B_3, 24ull * MB3 },
         { MODEL_30B_3, 32ull * MB3 },
+        { MODEL_34B_3, 38ull * MB3 }, // guess
         { MODEL_65B_3, 48ull * MB3 }, // guess
         { MODEL_70B_3, 48ull * MB3 },
     };
@@ -166,6 +170,7 @@ static const std::map<e_model3, size_t> & VRAM_REQ_SCRATCH_BASE_3()
         { MODEL_7B_3,  512ull * kB3 },
         { MODEL_13B_3, 640ull * kB3 },
         { MODEL_30B_3, 768ull * kB3 },
+        { MODEL_34B_3, 960ull * kB3 },
         { MODEL_65B_3, 1360ull * kB3 },
         { MODEL_70B_3, 1360ull * kB3 },
     };
@@ -181,6 +186,7 @@ static const std::map<e_model3, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT_3()
         { MODEL_7B_3,  128ull },
         { MODEL_13B_3, 160ull },
         { MODEL_30B_3, 208ull },
+        { MODEL_34B_3, 356ull },
         { MODEL_65B_3, 320ull },
         { MODEL_70B_3, 320ull },
     };
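
For reference, upstream llama.cpp of the same vintage combined these two VRAM tables per batch and per context, roughly as n_batch * (base + n_ctx * per_context). If this fork consumes the maps the same way (an assumption, not shown in this diff), a rough sketch of what the new 34B entries would imply for illustrative n_batch and n_ctx values:

// Sketch under the assumption that base and per-context entries are combined
// as in upstream llama.cpp at the time: n_batch * (base + n_ctx * per_context).
#include <cstddef>
#include <cstdio>

int main() {
    const size_t kB3     = 1024ull;           // assumed definition
    const size_t base    = 960ull * kB3;      // 34B entry from VRAM_REQ_SCRATCH_BASE_3 (this diff)
    const size_t per_ctx = 356ull;            // 34B entry from VRAM_REQ_SCRATCH_PER_CONTEXT_3 (this diff)
    const size_t n_batch = 512, n_ctx = 2048; // illustrative values
    const size_t vram_scratch = n_batch * (base + n_ctx * per_ctx);
    printf("approx. 34B VRAM scratch: %zu MiB\n", vram_scratch / (1024ull * 1024ull));  // ~836 MiB
}
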
@@ -1034,6 +1040,7 @@ static const char * llama_v3_model_type_name(e_model3 type) {
         case MODEL_7B_3: return "7B";
         case MODEL_13B_3: return "13B";
         case MODEL_30B_3: return "30B";
+        case MODEL_34B_3: return "34B";
         case MODEL_65B_3: return "65B";
         case MODEL_70B_3: return "70B";
         default: LLAMA_V3_ASSERT(false);
@@ -1082,6 +1089,7 @@ static void llama_v3_model_load_internal(
             case 26: model.type = e_model3::MODEL_3B_3; break;
             case 32: model.type = e_model3::MODEL_7B_3; break;
             case 40: model.type = e_model3::MODEL_13B_3; break;
+            case 48: model.type = e_model3::MODEL_34B_3; break;
             case 60: model.type = e_model3::MODEL_30B_3; break;
             case 80: model.type = e_model3::MODEL_65B_3; break;
             default:
@@ -1101,6 +1109,11 @@ static void llama_v3_model_load_internal(
             fprintf(stderr, "%s: Applying KCPP Patch for 70B model, setting GQA to 8\n", __func__);
             n_gqa = 8;
         }
+        if (model.type == e_model3::MODEL_34B_3) {
+            fprintf(stderr, "%s: Applying KCPP Patch for 34B model, setting GQA to 8\n", __func__);
+            n_gqa = 8;
+        }
         LLAMA_V3_ASSERT(hparams.n_head % n_gqa == 0);
         hparams.n_head_kv = hparams.n_head / n_gqa;
         if (model.type == e_model3::MODEL_65B_3 && n_gqa == 8) {
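
For reference, forcing n_gqa = 8 for 34B mirrors the existing 70B patch: Code Llama 34B uses grouped-query attention, and the loader derives the number of key/value heads as n_head / n_gqa. With the commonly reported 64 attention heads for the 34B model (an assumption, not read from the GGML file here), the assertion n_head % n_gqa == 0 holds and the result is 8 KV heads. A minimal illustration of that arithmetic:

// Minimal illustration of the GQA arithmetic the patch relies on.
// n_head = 64 is the commonly reported value for Code Llama 34B (assumed).
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_head = 64;            // attention heads (assumed for 34B)
    const uint32_t n_gqa  = 8;             // forced by the patch above
    assert(n_head % n_gqa == 0);           // same check as LLAMA_V3_ASSERT in the loader
    const uint32_t n_head_kv = n_head / n_gqa;
    printf("n_head_kv = %u\n", n_head_kv); // 8 key/value heads
}
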