Fix broken logic for parsing bool KV overrides
Fix issue where overrides didn't apply when key missing in GGUF metadata Resolve merge changes
This commit is contained in:
parent
2147421904
commit
aa7cf3143b
2 changed files with 18 additions and 21 deletions
|
@ -699,9 +699,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
} else if (strncmp(sep, "bool:", 5) == 0) {
|
||||
sep += 5;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
|
||||
if (std::strcmp(sep, "true")) {
|
||||
if (std::strcmp(sep, "true") == 0) {
|
||||
kvo.bool_value = true;
|
||||
} else if (std::strcmp(sep, "false")) {
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.bool_value = false;
|
||||
} else {
|
||||
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
||||
|
@ -888,6 +888,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
|
||||
printf(" -ld LOGDIR, --logdir LOGDIR\n");
|
||||
printf(" path under which to save YAML logs (no logging if unset)\n");
|
||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
|
||||
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
|
||||
printf("\n");
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_print_usage();
|
||||
|
|
30
llama.cpp
30
llama.cpp
|
@ -607,7 +607,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
|
|||
}
|
||||
}
|
||||
|
||||
static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
|
||||
static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
|
||||
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
|
||||
|
||||
switch (type) {
|
||||
|
@ -1895,16 +1895,13 @@ namespace GGUFMeta {
|
|||
if (try_override<T>(target, override)) {
|
||||
return true;
|
||||
}
|
||||
if (k < 0) { return false; }
|
||||
target = get_kv(ctx, k);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
|
||||
const int kid = gguf_find_key(ctx, key);
|
||||
if (kid < 0) {
|
||||
return false;
|
||||
}
|
||||
return set(ctx, kid, target, override);
|
||||
return set(ctx, gguf_find_key(ctx, key), target, override);
|
||||
}
|
||||
|
||||
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
|
||||
|
@ -2367,6 +2364,7 @@ static void llm_load_hparams(
|
|||
llama_model_loader & ml,
|
||||
llama_model & model) {
|
||||
auto & hparams = model.hparams;
|
||||
const gguf_context * ctx = ml.ctx_gguf;
|
||||
|
||||
// get metadata as string
|
||||
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
|
||||
|
@ -2678,19 +2676,15 @@ static void llm_load_vocab(
|
|||
}
|
||||
|
||||
// Handle add_bos_token and add_eos_token
|
||||
std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
|
||||
int kid = gguf_find_key(ctx, key.c_str());
|
||||
enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
|
||||
vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
|
||||
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
|
||||
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
|
||||
{
|
||||
bool temp = true;
|
||||
|
||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
|
||||
vocab.special_add_bos = int(temp);
|
||||
}
|
||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
|
||||
vocab.special_add_eos = int(temp);
|
||||
}
|
||||
key = kv(LLM_KV_TOKENIZER_ADD_EOS);
|
||||
kid = gguf_find_key(ctx, key.c_str());
|
||||
ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
|
||||
vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
|
||||
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
|
||||
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue