diff --git a/common/common.cpp b/common/common.cpp index f07b4d1a4..277244e5a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -235,7 +235,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } bool parse_kv_override(const char * data, std::vector & overrides) { - const char* sep = strchr(data, '='); + const char * sep = strchr(data, '='); if (sep == nullptr || sep - data >= 128) { fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); return false; @@ -247,18 +247,18 @@ bool parse_kv_override(const char * data, std::vector & if (strncmp(sep, "int:", 4) == 0) { sep += 4; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); + kvo.val_i64 = std::atol(sep); } else if (strncmp(sep, "float:", 6) == 0) { sep += 6; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); + kvo.val_f64 = std::atof(sep); } else if (strncmp(sep, "bool:", 5) == 0) { sep += 5; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; + kvo.val_bool = true; } else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; + kvo.val_bool = false; } else { fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); return false; @@ -266,7 +266,7 @@ bool parse_kv_override(const char * data, std::vector & } else if (strncmp(sep, "str:", 4) == 0) { sep += 4; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.str_value, sep, 128); + strncpy(kvo.val_str, sep, 128); } else { fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); return false; @@ -276,7 +276,7 @@ bool parse_kv_override(const char * data, std::vector & } bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { - llama_sampling_params& sparams = params.sparams; + llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { if (++i >= argc) { diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index b6464be3d..4419c0471 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -305,14 +305,14 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.str_value, imatrix_file.c_str(), 128); + strncpy(kvo.val_str, imatrix_file.c_str(), 128); kv_overrides.emplace_back(std::move(kvo)); } if (!imatrix_dataset.empty()) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.str_value, imatrix_dataset.c_str(), 128); + strncpy(kvo.val_str, imatrix_dataset.c_str(), 128); kv_overrides.emplace_back(std::move(kvo)); } @@ -320,7 +320,7 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = imatrix_data.size(); + kvo.val_i64 = imatrix_data.size(); kv_overrides.emplace_back(std::move(kvo)); } @@ -328,7 +328,7 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = m_last_call; + kvo.val_i64 = m_last_call; kv_overrides.emplace_back(std::move(kvo)); } } diff --git a/llama.cpp b/llama.cpp index 9d9f7b4e1..14c7f6741 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2875,16 +2875,16 @@ namespace GGUFMeta { __func__, override_type_to_str(ovrd->tag), ovrd->key); switch (ovrd->tag) { case LLAMA_KV_OVERRIDE_TYPE_BOOL: { - LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false"); + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); } break; case LLAMA_KV_OVERRIDE_TYPE_INT: { - LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value); + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); } break; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { - LLAMA_LOG_INFO("%.6f\n", ovrd->float_value); + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); } break; case LLAMA_KV_OVERRIDE_TYPE_STR: { - LLAMA_LOG_INFO("%s\n", ovrd->str_value); + LLAMA_LOG_INFO("%s\n", ovrd->val_str); } break; default: // Shouldn't be possible to end up here, but just in case... @@ -2903,7 +2903,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { - target = ovrd->bool_value; + target = ovrd->val_bool; return true; } return false; @@ -2913,7 +2913,7 @@ namespace GGUFMeta { static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { - target = ovrd->int_value; + target = ovrd->val_i64; return true; } return false; @@ -2923,17 +2923,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { - target = ovrd->float_value; - return true; - } - return false; - } - - template - static typename std::enable_if::value, char *>::type - try_override(T & target, const struct llama_model_kv_override * ovrd) { - if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { - target = ovrd->str_value; + target = ovrd->val_f64; return true; } return false; @@ -2942,12 +2932,11 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { - (void)target; - (void)ovrd; - if (!ovrd) { return false; } - // Currently, we should never end up here so it would be a bug if we do. - throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", - ovrd ? ovrd->key : "NULL")); + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; } static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { @@ -14276,13 +14265,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s for (auto & o : overrides) { if (o.key[0] == 0) break; if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out, o.key, o.float_value); + gguf_set_val_f32(ctx_out, o.key, o.val_f64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { - gguf_set_val_i32(ctx_out, o.key, o.int_value); + gguf_set_val_i32(ctx_out, o.key, o.val_i64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out, o.key, o.bool_value); + gguf_set_val_bool(ctx_out, o.key, o.val_bool); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { - gguf_set_val_str(ctx_out, o.key, o.str_value); + gguf_set_val_str(ctx_out, o.key, o.val_str); } else { LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); } diff --git a/llama.h b/llama.h index 73d9733f7..afe3c0466 100644 --- a/llama.h +++ b/llama.h @@ -199,13 +199,15 @@ extern "C" { }; struct llama_model_kv_override { - char key[128]; enum llama_model_kv_override_type tag; - char str_value[128]; + + char key[128]; + union { - int64_t int_value; - double float_value; - bool bool_value; + int64_t val_i64; + double val_f64; + bool val_bool; + char val_str[128]; }; };