even better approach

This commit is contained in:
Xuan Son Nguyen 2024-12-12 21:27:53 +01:00
parent 6b3696013c
commit 42a9c5a948

View file

@@ -145,41 +145,32 @@ static void common_params_handle_model_default(common_params & params) {
}
}
const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
{"f32", GGML_TYPE_F32},
{"f16", GGML_TYPE_F16},
{"bf16", GGML_TYPE_BF16},
{"q8_0", GGML_TYPE_Q8_0},
{"q4_0", GGML_TYPE_Q4_0},
{"q4_1", GGML_TYPE_Q4_1},
{"iq4_nl", GGML_TYPE_IQ4_NL},
{"q5_0", GGML_TYPE_Q5_0},
{"q5_1", GGML_TYPE_Q5_1},
const std::initializer_list<ggml_type> kv_cache_types = {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_BF16,
GGML_TYPE_Q8_0,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
};
// Converts a cache type name given by the user into its ggml_type.
//
// s: the name to look up; it must equal ggml_type_name() of one of the entries
//    in kv_cache_types exactly.
// Returns the matching entry of kv_cache_types.
// Throws std::runtime_error("Unsupported cache type: " + s) when nothing matches.
static ggml_type kv_cache_type_from_str(const std::string & s) {
    for (const auto & type : kv_cache_types) {
        // const char * == std::string resolves to std::string's operator==,
        // so this compares the characters, not the pointer
        if (ggml_type_name(type) == s) {
            return type;
        }
    }
    throw std::runtime_error("Unsupported cache type: " + s);
}
// Reverse lookup: returns the name stored next to cache type `t` in kv_cache_types.
// Each entry is read as a (name, type) pair through .first/.second, so this only
// works with the pair-based form of the table.
// Throws std::runtime_error when `t` is not listed; the message contains the
// numeric value of the enum as produced by std::to_string.
// NOTE(review): this source is a diff with old and new lines interleaved. The call
// sites visible further down show kv_cache_type_to_str(...) directly followed by
// ggml_type_name(...), so this helper looks like pre-change code that the commit
// drops - confirm before relying on it.
static const char * kv_cache_type_to_str(const ggml_type t) {
for (const auto & kv : kv_cache_types) {
if (kv.second == t) {
return kv.first;
}
}
throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
}
// Builds the list of allowed cache type names for the help text.
// Returns the ggml_type_name() of every entry in kv_cache_types, in table order,
// joined with ", " and with no trailing separator.
static std::string get_all_kv_cache_types() {
    std::ostringstream msg;
    const char * sep = ""; // nothing before the first name, ", " before every later one
    for (const auto & type : kv_cache_types) {
        msg << sep << ggml_type_name(type);
        sep = ", ";
    }
    return msg.str();
}
@@ -1218,7 +1209,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
kv_cache_type_to_str(params.cache_type_k)
ggml_type_name(params.cache_type_k)
),
[](common_params & params, const std::string & value) {
params.cache_type_k = kv_cache_type_from_str(value);
@@ -1231,7 +1222,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
kv_cache_type_to_str(params.cache_type_v)
ggml_type_name(params.cache_type_v)
),
[](common_params & params, const std::string & value) {
params.cache_type_v = kv_cache_type_from_str(value);