common : improve -ctv -ctk CLI arguments (#10806)
* common : improve ctv ctk cli argument * regenerate docs * even better approach * use std::vector
This commit is contained in:
parent
274ec65af6
commit
adffa6ffd5
5 changed files with 60 additions and 51 deletions
|
@ -145,6 +145,35 @@ static void common_params_handle_model_default(common_params & params) {
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<ggml_type> kv_cache_types = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
GGML_TYPE_BF16,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_IQ4_NL,
|
||||
GGML_TYPE_Q5_0,
|
||||
GGML_TYPE_Q5_1,
|
||||
};
|
||||
|
||||
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
||||
for (const auto & type : kv_cache_types) {
|
||||
if (ggml_type_name(type) == s) {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unsupported cache type: " + s);
|
||||
}
|
||||
|
||||
static std::string get_all_kv_cache_types() {
|
||||
std::ostringstream msg;
|
||||
for (const auto & type : kv_cache_types) {
|
||||
msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
|
||||
}
|
||||
return msg.str();
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
@ -1174,18 +1203,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
|
||||
add_opt(common_arg(
|
||||
{"-ctk", "--cache-type-k"}, "TYPE",
|
||||
string_format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
|
||||
string_format(
|
||||
"KV cache data type for K\n"
|
||||
"allowed values: %s\n"
|
||||
"(default: %s)",
|
||||
get_all_kv_cache_types().c_str(),
|
||||
ggml_type_name(params.cache_type_k)
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
// TODO: get the type right here
|
||||
params.cache_type_k = value;
|
||||
params.cache_type_k = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_K"));
|
||||
add_opt(common_arg(
|
||||
{"-ctv", "--cache-type-v"}, "TYPE",
|
||||
string_format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
|
||||
string_format(
|
||||
"KV cache data type for V\n"
|
||||
"allowed values: %s\n"
|
||||
"(default: %s)",
|
||||
get_all_kv_cache_types().c_str(),
|
||||
ggml_type_name(params.cache_type_v)
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
// TODO: get the type right here
|
||||
params.cache_type_v = value;
|
||||
params.cache_type_v = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
|
||||
add_opt(common_arg(
|
||||
|
|
|
@ -1015,38 +1015,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
|||
return mparams;
|
||||
}
|
||||
|
||||
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
||||
if (s == "f32") {
|
||||
return GGML_TYPE_F32;
|
||||
}
|
||||
if (s == "f16") {
|
||||
return GGML_TYPE_F16;
|
||||
}
|
||||
if (s == "bf16") {
|
||||
return GGML_TYPE_BF16;
|
||||
}
|
||||
if (s == "q8_0") {
|
||||
return GGML_TYPE_Q8_0;
|
||||
}
|
||||
if (s == "q4_0") {
|
||||
return GGML_TYPE_Q4_0;
|
||||
}
|
||||
if (s == "q4_1") {
|
||||
return GGML_TYPE_Q4_1;
|
||||
}
|
||||
if (s == "iq4_nl") {
|
||||
return GGML_TYPE_IQ4_NL;
|
||||
}
|
||||
if (s == "q5_0") {
|
||||
return GGML_TYPE_Q5_0;
|
||||
}
|
||||
if (s == "q5_1") {
|
||||
return GGML_TYPE_Q5_1;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unsupported cache type: " + s);
|
||||
}
|
||||
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
||||
auto cparams = llama_context_default_params();
|
||||
|
||||
|
@ -1081,8 +1049,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
|||
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
}
|
||||
|
||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
|
||||
return cparams;
|
||||
}
|
||||
|
|
|
@ -286,8 +286,8 @@ struct common_params {
|
|||
bool warmup = true; // warmup run
|
||||
bool check_tensors = false; // validate tensor data
|
||||
|
||||
std::string cache_type_k = "f16"; // KV cache data type for the K
|
||||
std::string cache_type_v = "f16"; // KV cache data type for the V
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
// multimodal models (see examples/llava)
|
||||
std::string mmproj = ""; // path to multimodal projector // NOLINT
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue