From ba35f29e81fdea2d29b4cb0643b474f99436c409 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 12 Dec 2024 21:09:22 +0100
Subject: [PATCH] common : improve ctv ctk cli argument

---
 common/arg.cpp    | 61 ++++++++++++++++++++++++++++++++++++++++++-----
 common/common.cpp | 36 ++--------------------------
 common/common.h   |  4 ++--
 3 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index b27567f3b..a1280f6e2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -145,6 +145,45 @@ static void common_params_handle_model_default(common_params & params) {
     }
 }
 
+const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
+    {"f32", GGML_TYPE_F32},
+    {"f16", GGML_TYPE_F16},
+    {"bf16", GGML_TYPE_BF16},
+    {"q8_0", GGML_TYPE_Q8_0},
+    {"q4_0", GGML_TYPE_Q4_0},
+    {"q4_1", GGML_TYPE_Q4_1},
+    {"iq4_nl", GGML_TYPE_IQ4_NL},
+    {"q5_0", GGML_TYPE_Q5_0},
+    {"q5_1", GGML_TYPE_Q5_1},
+};
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    for (const auto & kv : kv_cache_types) {
+        if (kv.first == s) {
+            return kv.second;
+        }
+    }
+    throw std::runtime_error("Unsupported cache type: " + s);
+}
+
+static const char * kv_cache_type_to_str(const ggml_type t) {
+    for (const auto & kv : kv_cache_types) {
+        if (kv.second == t) {
+            return kv.first;
+        }
+    }
+    throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
+}
+
+static std::string get_all_kv_cache_types() {
+    std::ostringstream msg;
+    size_t size = kv_cache_types.size();
+    for (size_t i = 0; i < size; i++) {
+        msg << (kv_cache_types.begin() + i)->first << (i+1 == size ? "" : ", ");
+    }
+    return msg.str();
+}
+
 //
 // CLI argument parsing functions
 //
@@ -1174,18 +1213,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
     add_opt(common_arg(
         {"-ctk", "--cache-type-k"}, "TYPE",
-        string_format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
+        string_format(
+            "KV cache data type for K\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            kv_cache_type_to_str(params.cache_type_k)
+        ),
         [](common_params & params, const std::string & value) {
-            // TODO: get the type right here
-            params.cache_type_k = value;
+            params.cache_type_k = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_K"));
     add_opt(common_arg(
         {"-ctv", "--cache-type-v"}, "TYPE",
-        string_format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
+        string_format(
+            "KV cache data type for V\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            kv_cache_type_to_str(params.cache_type_v)
+        ),
         [](common_params & params, const std::string & value) {
-            // TODO: get the type right here
-            params.cache_type_v = value;
+            params.cache_type_v = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
     add_opt(common_arg(
diff --git a/common/common.cpp b/common/common.cpp
index 6143516d2..3cd43ecdf 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1015,38 +1015,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     return mparams;
 }
 
-static ggml_type kv_cache_type_from_str(const std::string & s) {
-    if (s == "f32") {
-        return GGML_TYPE_F32;
-    }
-    if (s == "f16") {
-        return GGML_TYPE_F16;
-    }
-    if (s == "bf16") {
-        return GGML_TYPE_BF16;
-    }
-    if (s == "q8_0") {
-        return GGML_TYPE_Q8_0;
-    }
-    if (s == "q4_0") {
-        return GGML_TYPE_Q4_0;
-    }
-    if (s == "q4_1") {
-        return GGML_TYPE_Q4_1;
-    }
-    if (s == "iq4_nl") {
-        return GGML_TYPE_IQ4_NL;
-    }
-    if (s == "q5_0") {
-        return GGML_TYPE_Q5_0;
-    }
-    if (s == "q5_1") {
-        return GGML_TYPE_Q5_1;
-    }
-
-    throw std::runtime_error("Unsupported cache type: " + s);
-}
-
 struct llama_context_params common_context_params_to_llama(const common_params & params) {
     auto cparams = llama_context_default_params();
 
@@ -1081,8 +1049,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
         cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
     }
 
-    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
-    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+    cparams.type_k = params.cache_type_k;
+    cparams.type_v = params.cache_type_v;
 
     return cparams;
 }
diff --git a/common/common.h b/common/common.h
index 95d20401d..0481720ab 100644
--- a/common/common.h
+++ b/common/common.h
@@ -286,8 +286,8 @@ struct common_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
 
-    std::string cache_type_k = "f16"; // KV cache data type for the K
-    std::string cache_type_v = "f16"; // KV cache data type for the V
+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT