llama : pass KV cache type through API

This commit is contained in:
Georgi Gerganov 2023-12-05 15:40:23 +02:00
parent b881f630ca
commit 3ce30e07c9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 59 additions and 12 deletions

View file

@ -500,6 +500,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") {
params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") {
params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
@ -844,6 +848,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" verbose print of the KV cache\n");
printf(" -nkvo, --no-kv-offload\n");
printf(" disable KV offload\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
printf(" -ctv TYPE, --cache-type-v TYPE\n");
printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@ -908,6 +916,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
return mparams;
}
static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
throw std::runtime_error("Invalid cache type: " + s);
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
@ -930,6 +961,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.offload_kqv = !params.no_kv_offload;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
return cparams;
}