From ea4abc9d8255d27dbcb3d87473579781d04a59d8 Mon Sep 17 00:00:00 2001
From: mare5x
Date: Mon, 1 Jul 2024 11:51:39 +0200
Subject: [PATCH] token healing : refactor argument parsing

Unify `main` and `server` token healing argument handling.
---
 common/common.cpp          | 15 +-----------
 common/sampling.cpp        | 28 +++++++++++++++++++----
 common/sampling.h          | 13 ++++++++---
 examples/main/main.cpp     | 18 +++++++--------
 examples/server/server.cpp | 47 +++++++++++++++-----------------------
 5 files changed, 61 insertions(+), 60 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index e5eccc54e..141abaef8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "-th" || arg == "--token-healing") {
         CHECK_ARG
-        sparams.token_healing_enabled = true;
-        auto & th_type = sparams.token_healing_type;
-        auto & th_n_rollback = sparams.token_healing_n_rollback;
         std::string value(argv[i]);
-        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
-        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
-        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
-        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
-        else if (value[0] == 'r' ) {
-            th_type = llama_token_healing_type::ROLLBACK_MULTI;
-            th_n_rollback = std::stoi(value.substr(1));
-            if (th_n_rollback <= 0) {
-                sparams.token_healing_enabled = false;
-            }
-        } else { invalid_param = true; }
+        invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
         return true;
     }
     if (arg == "--override-kv") {
diff --git a/common/sampling.cpp b/common/sampling.cpp
index b5c6b9ad3..2d1610b39 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
     ctx_sampling->token_healing_prefix = prefix;
 }
 
+bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
+    th_params.enabled = true;
+    th_params.n_rollback = -1;
+    /**/ if (params == "0" ) { th_params.enabled = false; }
+    else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
+    else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
+    else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
+    else if (params[0] == 'r' ) {
+        th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
+        th_params.n_rollback = std::stoi(params.substr(1));
+        if (th_params.n_rollback <= 0) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
 //
 // Sampling
 //
@@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
     cur.resize(n_vocab);
 
     // Constrain tokens based on the remaining token healing prefix (if any)
-    const auto & th_type = params.token_healing_type;
     const auto & th_prefix = ctx_sampling->token_healing_prefix;
-    if (params.token_healing_enabled && !th_prefix.empty()) {
-        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
-                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    if (params.token_healing.enabled && !th_prefix.empty()) {
+        const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
         std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); @@ -635,7 +653,7 @@ void llama_sampling_accept( llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } - if (ctx_sampling->params.token_healing_enabled && apply_grammar) { + if (ctx_sampling->params.token_healing.enabled && apply_grammar) { std::string & th_prefix = ctx_sampling->token_healing_prefix; if (!th_prefix.empty()) { const std::string new_token_piece = llama_token_to_piece(ctx_main, id); diff --git a/common/sampling.h b/common/sampling.h index 094b40c89..a269ab11e 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t { DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps }; +struct llama_token_healing_params { + bool enabled = false; + llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI; + int n_rollback = -1; // number of tokens to roll back +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -70,9 +76,7 @@ typedef struct llama_sampling_params { std::vector penalty_prompt_tokens; bool use_penalty_prompt_tokens = false; - llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; - bool token_healing_enabled = false; - int token_healing_n_rollback = -1; // number of tokens to roll back + llama_token_healing_params token_healing; } llama_sampling_params; // general sampler context @@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback( int max_to_remove = -1); void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); + +// Helper for parsing token healing params from a string. +bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6976b2697..e8a0eefb9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -291,14 +291,14 @@ int main(int argc, char ** argv) { LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); } - if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { - sparams.token_healing_enabled = false; + if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) { + sparams.token_healing.enabled = false; LOG("token healing: disabled due to custom suffix/conversation mode"); } llama_token_healing_output token_healing_out{}; - if (!params.interactive_first && sparams.token_healing_enabled) { - token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, - sparams.token_healing_n_rollback); + if (!params.interactive_first && sparams.token_healing.enabled) { + token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, + sparams.token_healing.n_rollback); } // Should not run without any tokens @@ -956,13 +956,13 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); - if (sparams.token_healing_enabled) { + if (sparams.token_healing.enabled) { // Limit token healing rollback to new tokens only (otherwise would need to shift everything) const int n_new_tokens = embd_inp.size() - original_size; - const int max_to_remove = sparams.token_healing_n_rollback < 0 + const int max_to_remove = sparams.token_healing.n_rollback < 0 ? 
                             ? n_new_tokens
-                            : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                            : std::min(sparams.token_healing.n_rollback, n_new_tokens);
+                        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
                         n_bytes_to_skip = token_healing_out.prefix.size();
                     }
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 01aac8ed5..ef2d7fa21 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1098,31 +1098,20 @@ struct server_context {
         {
             const auto & token_healing_str = data.find("token_healing");
-            auto & th_enabled = slot.sparams.token_healing_enabled;
-            th_enabled = default_sparams.token_healing_enabled;
             if (token_healing_str != data.end() && token_healing_str->is_string()) {
                 const auto value = token_healing_str->get<std::string>();
-                auto & th_type = slot.sparams.token_healing_type;
-                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
-                th_enabled = true;
-                /**/ if (value == "0" ) { th_enabled = false; }
-                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
-                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
-                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
-                else if (value[0] == 'r' ) {
-                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
-                    th_n_rollback = std::stoi(value.substr(1));
-                    if (th_n_rollback <= 0) {
-                        th_enabled = false;
-                    }
-                } else { th_enabled = false; }
-
+                if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
+                    send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
                 LOG_VERBOSE("token healing", {
                     {"id_slot", slot.id},
-                    {"enabled", th_enabled},
-                    {"type", th_type},
-                    {"n_rollback", th_n_rollback}
+                    {"enabled", slot.sparams.token_healing.enabled},
+                    {"type", slot.sparams.token_healing.type},
+                    {"n_rollback", slot.sparams.token_healing.n_rollback}
                 });
+            } else {
+                slot.sparams.token_healing = default_sparams.token_healing;
             }
         }
 
@@ -1406,7 +1395,7 @@ struct server_context {
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
             {"samplers", samplers_sequence},
-            {"token_healing_enabled", slot.sparams.token_healing_enabled}
+            {"token_healing_enabled", slot.sparams.token_healing.enabled}
         };
     }
 
@@ -2109,10 +2098,10 @@ struct server_context {
                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                     suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
 
-                    if (slot.sparams.token_healing_enabled) {
+                    if (slot.sparams.token_healing.enabled) {
                         // For FIM roll back only the prefix part (i.e. cursor location)
-                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
-                                                                         prefix_tokens, slot.sparams.token_healing_n_rollback);
+                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
+                                                                         prefix_tokens, slot.sparams.token_healing.n_rollback);
                     }
 
                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
@@ -2131,9 +2120,9 @@ struct server_context {
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
 
-                    if (slot.sparams.token_healing_enabled) {
-                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
-                                                                         prompt_tokens, slot.sparams.token_healing_n_rollback);
+                    if (slot.sparams.token_healing.enabled) {
+                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
+                                                                         prompt_tokens, slot.sparams.token_healing.n_rollback);
                     }
                 }
 
@@ -2149,7 +2138,7 @@ struct server_context {
                     {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                 });
 
-                if (slot.sparams.token_healing_enabled) {
+                if (slot.sparams.token_healing.enabled) {
                     slot.n_th_prefix = token_healing_out.prefix.size();
                     LOG_VERBOSE("token healing prompt", {
                         {"id_slot", slot.id},
@@ -2224,7 +2213,7 @@ struct server_context {
                 }
 
                 llama_sampling_reset(slot.ctx_sampling);
-                if (slot.sparams.token_healing_enabled) {
+                if (slot.sparams.token_healing.enabled) {
                     llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
                 }
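
---

For reference, a minimal sketch of the values the unified parser accepts,
exercising `llama_token_healing_parse_params` exactly as added in
common/sampling.cpp above (the include path follows the examples' usage of
the common headers; the CLI reaches the helper via `-th`/`--token-healing`,
the server via the "token_healing" request field):

    #include "sampling.h"

    // Each call resets enabled/n_rollback before parsing, so `th` can be reused.
    llama_token_healing_params th;
    llama_token_healing_parse_params("0",  th); // true:  th.enabled == false
    llama_token_healing_parse_params("1",  th); // true:  type = ROLLBACK_LAST
    llama_token_healing_parse_params("d1", th); // true:  type = DYNAMIC_ONCE
    llama_token_healing_parse_params("d",  th); // true:  type = DYNAMIC_MULTI
    llama_token_healing_parse_params("r4", th); // true:  type = ROLLBACK_MULTI, n_rollback == 4
    llama_token_healing_parse_params("r0", th); // false: rollback count must be positive
    llama_token_healing_parse_params("x",  th); // false: unknown value

One behavioural consequence of the refactor: the server now rejects an
unparseable "token_healing" string with ERROR_TYPE_INVALID_REQUEST instead
of silently disabling token healing for the slot.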