From 42a9c5a948a610bbc3dfafbf972ca5ca7efb4290 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 12 Dec 2024 21:27:53 +0100
Subject: [PATCH] even better approach

---
 common/arg.cpp | 41 ++++++++++++++++-------------------------
 1 file changed, 16 insertions(+), 25 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index a1280f6e2..8cc085ac4 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -145,41 +145,32 @@ static void common_params_handle_model_default(common_params & params) {
     }
 }
 
-const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
-    {"f32", GGML_TYPE_F32},
-    {"f16", GGML_TYPE_F16},
-    {"bf16", GGML_TYPE_BF16},
-    {"q8_0", GGML_TYPE_Q8_0},
-    {"q4_0", GGML_TYPE_Q4_0},
-    {"q4_1", GGML_TYPE_Q4_1},
-    {"iq4_nl", GGML_TYPE_IQ4_NL},
-    {"q5_0", GGML_TYPE_Q5_0},
-    {"q5_1", GGML_TYPE_Q5_1},
+const std::initializer_list<ggml_type> kv_cache_types = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_BF16,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_IQ4_NL,
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
 };
 
 static ggml_type kv_cache_type_from_str(const std::string & s) {
-    for (const auto & kv : kv_cache_types) {
-        if (kv.first == s) {
-            return kv.second;
+    for (const auto & type : kv_cache_types) {
+        if (ggml_type_name(type) == s) {
+            return type;
         }
     }
     throw std::runtime_error("Unsupported cache type: " + s);
 }
 
-static const char * kv_cache_type_to_str(const ggml_type t) {
-    for (const auto & kv : kv_cache_types) {
-        if (kv.second == t) {
-            return kv.first;
-        }
-    }
-    throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
-}
-
 static std::string get_all_kv_cache_types() {
     std::ostringstream msg;
     size_t size = kv_cache_types.size();
     for (size_t i = 0; i < size; i++) {
-        msg << (kv_cache_types.begin() + i)->first << (i+1 == size ? "" : ", ");
+        msg << ggml_type_name(*(kv_cache_types.begin()+i)) << (i+1 == size ? "" : ", ");
     }
     return msg.str();
 }
@@ -1218,7 +1209,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "allowed values: %s\n"
             "(default: %s)",
             get_all_kv_cache_types().c_str(),
-            kv_cache_type_to_str(params.cache_type_k)
+            ggml_type_name(params.cache_type_k)
         ),
         [](common_params & params, const std::string & value) {
             params.cache_type_k = kv_cache_type_from_str(value);
@@ -1231,7 +1222,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "allowed values: %s\n"
             "(default: %s)",
             get_all_kv_cache_types().c_str(),
-            kv_cache_type_to_str(params.cache_type_v)
+            ggml_type_name(params.cache_type_v)
         ),
         [](common_params & params, const std::string & value) {
             params.cache_type_v = kv_cache_type_from_str(value);