even better approach

This commit is contained in:
Xuan Son Nguyen 2024-12-12 21:27:53 +01:00
parent 6b3696013c
commit 42a9c5a948

View file

@ -145,41 +145,32 @@ static void common_params_handle_model_default(common_params & params) {
} }
} }
const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = { const std::initializer_list<ggml_type> kv_cache_types = {
{"f32", GGML_TYPE_F32}, GGML_TYPE_F32,
{"f16", GGML_TYPE_F16}, GGML_TYPE_F16,
{"bf16", GGML_TYPE_BF16}, GGML_TYPE_BF16,
{"q8_0", GGML_TYPE_Q8_0}, GGML_TYPE_Q8_0,
{"q4_0", GGML_TYPE_Q4_0}, GGML_TYPE_Q4_0,
{"q4_1", GGML_TYPE_Q4_1}, GGML_TYPE_Q4_1,
{"iq4_nl", GGML_TYPE_IQ4_NL}, GGML_TYPE_IQ4_NL,
{"q5_0", GGML_TYPE_Q5_0}, GGML_TYPE_Q5_0,
{"q5_1", GGML_TYPE_Q5_1}, GGML_TYPE_Q5_1,
}; };
static ggml_type kv_cache_type_from_str(const std::string & s) { static ggml_type kv_cache_type_from_str(const std::string & s) {
for (const auto & kv : kv_cache_types) { for (const auto & type : kv_cache_types) {
if (kv.first == s) { if (ggml_type_name(type) == s) {
return kv.second; return type;
} }
} }
throw std::runtime_error("Unsupported cache type: " + s); throw std::runtime_error("Unsupported cache type: " + s);
} }
static const char * kv_cache_type_to_str(const ggml_type t) {
for (const auto & kv : kv_cache_types) {
if (kv.second == t) {
return kv.first;
}
}
throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
}
static std::string get_all_kv_cache_types() { static std::string get_all_kv_cache_types() {
std::ostringstream msg; std::ostringstream msg;
size_t size = kv_cache_types.size(); size_t size = kv_cache_types.size();
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
msg << (kv_cache_types.begin() + i)->first << (i+1 == size ? "" : ", "); msg << ggml_type_name(*(kv_cache_types.begin()+i)) << (i+1 == size ? "" : ", ");
} }
return msg.str(); return msg.str();
} }
@ -1218,7 +1209,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"allowed values: %s\n" "allowed values: %s\n"
"(default: %s)", "(default: %s)",
get_all_kv_cache_types().c_str(), get_all_kv_cache_types().c_str(),
kv_cache_type_to_str(params.cache_type_k) ggml_type_name(params.cache_type_k)
), ),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.cache_type_k = kv_cache_type_from_str(value); params.cache_type_k = kv_cache_type_from_str(value);
@ -1231,7 +1222,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"allowed values: %s\n" "allowed values: %s\n"
"(default: %s)", "(default: %s)",
get_all_kv_cache_types().c_str(), get_all_kv_cache_types().c_str(),
kv_cache_type_to_str(params.cache_type_v) ggml_type_name(params.cache_type_v)
), ),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
params.cache_type_v = kv_cache_type_from_str(value); params.cache_type_v = kv_cache_type_from_str(value);