Fix broken logic for parsing bool KV overrides

Fix issue where overrides didn't apply when key missing in GGUF metadata
Resolve merge changes
This commit is contained in:
KerfuffleV2 2023-11-18 03:07:03 -07:00
parent 2147421904
commit aa7cf3143b
2 changed files with 18 additions and 21 deletions

View file

@ -699,9 +699,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} else if (strncmp(sep, "bool:", 5) == 0) { } else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5; sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_BOOL; kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
if (std::strcmp(sep, "true")) { if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true; kvo.bool_value = true;
} else if (std::strcmp(sep, "false")) { } else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false; kvo.bool_value = false;
} else { } else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
@ -888,6 +888,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str()); printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" -ld LOGDIR, --logdir LOGDIR\n"); printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n"); printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf("\n"); printf("\n");
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_print_usage(); log_print_usage();

View file

@ -607,7 +607,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
} }
} }
static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) { static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
switch (type) { switch (type) {
@ -1895,16 +1895,13 @@ namespace GGUFMeta {
if (try_override<T>(target, override)) { if (try_override<T>(target, override)) {
return true; return true;
} }
if (k < 0) { return false; }
target = get_kv(ctx, k); target = get_kv(ctx, k);
return true; return true;
} }
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) { static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
const int kid = gguf_find_key(ctx, key); return set(ctx, gguf_find_key(ctx, key), target, override);
if (kid < 0) {
return false;
}
return set(ctx, kid, target, override);
} }
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) { static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
@ -2367,6 +2364,7 @@ static void llm_load_hparams(
llama_model_loader & ml, llama_model_loader & ml,
llama_model & model) { llama_model & model) {
auto & hparams = model.hparams; auto & hparams = model.hparams;
const gguf_context * ctx = ml.ctx_gguf;
// get metadata as string // get metadata as string
for (int i = 0; i < gguf_get_n_kv(ctx); i++) { for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
@ -2678,19 +2676,15 @@ static void llm_load_vocab(
} }
// Handle add_bos_token and add_eos_token // Handle add_bos_token and add_eos_token
std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS); {
int kid = gguf_find_key(ctx, key.c_str()); bool temp = true;
enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1; if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) { vocab.special_add_bos = int(temp);
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str()); }
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.special_add_eos = int(temp);
} }
key = kv(LLM_KV_TOKENIZER_ADD_EOS);
kid = gguf_find_key(ctx, key.c_str());
ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
} }
} }