common : improve ctv ctk cli argument
This commit is contained in:
parent
8faa1d4dd4
commit
ba35f29e81
3 changed files with 59 additions and 42 deletions
|
@ -145,6 +145,45 @@ static void common_params_handle_model_default(common_params & params) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
|
||||||
|
{"f32", GGML_TYPE_F32},
|
||||||
|
{"f16", GGML_TYPE_F16},
|
||||||
|
{"bf16", GGML_TYPE_BF16},
|
||||||
|
{"q8_0", GGML_TYPE_Q8_0},
|
||||||
|
{"q4_0", GGML_TYPE_Q4_0},
|
||||||
|
{"q4_1", GGML_TYPE_Q4_1},
|
||||||
|
{"iq4_nl", GGML_TYPE_IQ4_NL},
|
||||||
|
{"q5_0", GGML_TYPE_Q5_0},
|
||||||
|
{"q5_1", GGML_TYPE_Q5_1},
|
||||||
|
};
|
||||||
|
|
||||||
|
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
||||||
|
for (const auto & kv : kv_cache_types) {
|
||||||
|
if (kv.first == s) {
|
||||||
|
return kv.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw std::runtime_error("Unsupported cache type: " + s);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * kv_cache_type_to_str(const ggml_type t) {
|
||||||
|
for (const auto & kv : kv_cache_types) {
|
||||||
|
if (kv.second == t) {
|
||||||
|
return kv.first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string get_all_kv_cache_types() {
|
||||||
|
std::ostringstream msg;
|
||||||
|
size_t size = kv_cache_types.size();
|
||||||
|
for (size_t i = 0; i < size; i++) {
|
||||||
|
msg << (kv_cache_types.begin() + i)->first << (i+1 == size ? "" : ", ");
|
||||||
|
}
|
||||||
|
return msg.str();
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CLI argument parsing functions
|
// CLI argument parsing functions
|
||||||
//
|
//
|
||||||
|
@ -1174,18 +1213,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
|
).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"-ctk", "--cache-type-k"}, "TYPE",
|
{"-ctk", "--cache-type-k"}, "TYPE",
|
||||||
string_format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
|
string_format(
|
||||||
|
"KV cache data type for K\n"
|
||||||
|
"allowed values: %s\n"
|
||||||
|
"(default: %s)",
|
||||||
|
get_all_kv_cache_types().c_str(),
|
||||||
|
kv_cache_type_to_str(params.cache_type_k)
|
||||||
|
),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
// TODO: get the type right here
|
params.cache_type_k = kv_cache_type_from_str(value);
|
||||||
params.cache_type_k = value;
|
|
||||||
}
|
}
|
||||||
).set_env("LLAMA_ARG_CACHE_TYPE_K"));
|
).set_env("LLAMA_ARG_CACHE_TYPE_K"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"-ctv", "--cache-type-v"}, "TYPE",
|
{"-ctv", "--cache-type-v"}, "TYPE",
|
||||||
string_format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
|
string_format(
|
||||||
|
"KV cache data type for V\n"
|
||||||
|
"allowed values: %s\n"
|
||||||
|
"(default: %s)",
|
||||||
|
get_all_kv_cache_types().c_str(),
|
||||||
|
kv_cache_type_to_str(params.cache_type_v)
|
||||||
|
),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
// TODO: get the type right here
|
params.cache_type_v = kv_cache_type_from_str(value);
|
||||||
params.cache_type_v = value;
|
|
||||||
}
|
}
|
||||||
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
|
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
|
@ -1015,38 +1015,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||||
return mparams;
|
return mparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|
||||||
if (s == "f32") {
|
|
||||||
return GGML_TYPE_F32;
|
|
||||||
}
|
|
||||||
if (s == "f16") {
|
|
||||||
return GGML_TYPE_F16;
|
|
||||||
}
|
|
||||||
if (s == "bf16") {
|
|
||||||
return GGML_TYPE_BF16;
|
|
||||||
}
|
|
||||||
if (s == "q8_0") {
|
|
||||||
return GGML_TYPE_Q8_0;
|
|
||||||
}
|
|
||||||
if (s == "q4_0") {
|
|
||||||
return GGML_TYPE_Q4_0;
|
|
||||||
}
|
|
||||||
if (s == "q4_1") {
|
|
||||||
return GGML_TYPE_Q4_1;
|
|
||||||
}
|
|
||||||
if (s == "iq4_nl") {
|
|
||||||
return GGML_TYPE_IQ4_NL;
|
|
||||||
}
|
|
||||||
if (s == "q5_0") {
|
|
||||||
return GGML_TYPE_Q5_0;
|
|
||||||
}
|
|
||||||
if (s == "q5_1") {
|
|
||||||
return GGML_TYPE_Q5_1;
|
|
||||||
}
|
|
||||||
|
|
||||||
throw std::runtime_error("Unsupported cache type: " + s);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
||||||
auto cparams = llama_context_default_params();
|
auto cparams = llama_context_default_params();
|
||||||
|
|
||||||
|
@ -1081,8 +1049,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||||
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||||
}
|
}
|
||||||
|
|
||||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
cparams.type_k = params.cache_type_k;
|
||||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
cparams.type_v = params.cache_type_v;
|
||||||
|
|
||||||
return cparams;
|
return cparams;
|
||||||
}
|
}
|
||||||
|
|
|
@ -286,8 +286,8 @@ struct common_params {
|
||||||
bool warmup = true; // warmup run
|
bool warmup = true; // warmup run
|
||||||
bool check_tensors = false; // validate tensor data
|
bool check_tensors = false; // validate tensor data
|
||||||
|
|
||||||
std::string cache_type_k = "f16"; // KV cache data type for the K
|
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||||
std::string cache_type_v = "f16"; // KV cache data type for the V
|
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||||
|
|
||||||
// multimodal models (see examples/llava)
|
// multimodal models (see examples/llava)
|
||||||
std::string mmproj = ""; // path to multimodal projector // NOLINT
|
std::string mmproj = ""; // path to multimodal projector // NOLINT
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue