From 7d0cc78bc32725fc4ffb29bf854754575990039a Mon Sep 17 00:00:00 2001
From: mare5x
Date: Fri, 3 May 2024 19:50:00 +0200
Subject: [PATCH] main : better token healing support for interactive mode

---
 common/common.cpp      |  2 +-
 common/sampling.cpp    | 31 ++++++++++++++++++++-----------
 common/sampling.h      |  5 +++--
 examples/main/main.cpp | 23 +++++++++++++++++++++--
 4 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 7f1d13605..b75cfdf95 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1298,7 +1298,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         auto & th_n_rollback = sparams.token_healing_n_rollback;
         std::string value(argv[i]);
         /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
-        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
         else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
         else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
         else if (value[0] == 'r' ) {
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5549369e8..03c2664bb 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -46,20 +46,26 @@ std::string llama_token_healing_prepare(
         const llama_context * ctx_main,
         llama_token_healing_type th_type,
         std::vector<llama_token> & tokens,
-        int n_rollback) {
+        int max_to_remove,
+        int * n_removed) {
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
     if (tokens.empty()) {
         return "";
     }
+
     const llama_model * model = llama_get_model(ctx_main);
     const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
     const int n_ctx = tokens.size();
-    const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
-    int n_removed = 0;
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
+    int removed = 0;
     std::string prefix;
     // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
     // and stop early if a special token is encountered
-    while (n_removed < max_to_remove) {
-        const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
         if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
             // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
             break;
@@ -68,23 +74,26 @@ std::string llama_token_healing_prepare(
         if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
             break;
         }
-        n_removed += 1;
+        removed += 1;
         prefix = new_prefix;
     }
-
-    if (n_removed == 0) { // E.g. if the last token is a special token
+    if (removed == 0) { // E.g. if the last token is a special token
         return "";
     }
     // If constrained decoding would give back the original prompt, there is no need to modify the context
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
     const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
-    if (n_removed == 1 && candidates.size() == 1) {
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
         return "";
     }
-    tokens.resize(n_ctx - n_removed);
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
     return prefix;
 }
 
diff --git a/common/sampling.h b/common/sampling.h
index e2b870f00..90198bec9 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -72,7 +72,7 @@ typedef struct llama_sampling_params {
 
     llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
     bool token_healing_enabled = false;
-    int token_healing_n_rollback = 1; // number of tokens to roll back
+    int token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -174,4 +174,5 @@ std::string llama_token_healing_prepare(
         const llama_context * ctx_main,
         llama_token_healing_type th_type,
         std::vector<llama_token> & tokens,
-        int n_rollback = 1);
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c9e6d2de9..aedc40334 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,8 +264,12 @@ int main(int argc, char ** argv) {
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
     LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
+    if (sparams.token_healing_enabled && (params.instruct || params.chatml || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix");
+    }
     std::string token_healing_prefix;
-    if (sparams.token_healing_enabled) {
+    if (!params.interactive_first && sparams.token_healing_enabled) {
         token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
                                                            sparams.token_healing_n_rollback);
     }
@@ -820,6 +824,7 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            int token_healing_n_removed = 0;
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
@@ -903,13 +908,23 @@ int main(int argc, char ** argv) {
                         embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
                     }
 
+                    if (sparams.token_healing_enabled) {
+                        // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
+                        const int n_new_tokens = embd_inp.size() - original_size;
+                        const int max_to_remove = sparams.token_healing_n_rollback < 0
+                                                  ? n_new_tokens
+                                                  : std::min(sparams.token_healing_n_rollback, n_new_tokens);
+                        token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
+                                                                           max_to_remove, &token_healing_n_removed);
+                    }
+
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
                         output_tokens.push_back(token);
                         output_ss << llama_token_to_piece(ctx, token);
                     }
 
-                    n_remain -= line_inp.size();
+                    n_remain -= line_inp.size() + token_healing_n_removed;
                     LOG("n_remain: %d\n", n_remain);
                 } else {
                     LOG("empty line, passing control back\n");
@@ -921,6 +936,10 @@ int main(int argc, char ** argv) {
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
+                if (token_healing_n_removed > 0) {
+                    // Set new prefix after an interaction
+                    ctx_sampling->token_healing_prefix = token_healing_prefix;
+                }
            }
            is_interacting = false;
        }
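
For reference, a minimal sketch of how a caller might use the updated llama_token_healing_prepare() signature after new user input has been appended. The helper function, its name, and the surrounding setup are illustrative assumptions, not part of the patch; only the declarations from common/sampling.h above are taken as given.

// Illustrative sketch only; assumes the declarations shown in common/sampling.h.
#include "common.h"
#include "sampling.h"

// Hypothetical helper: heal the tail of `tokens` after `n_new_tokens` were appended.
static void apply_token_healing(llama_context * ctx,
                                llama_sampling_context * ctx_sampling,
                                std::vector<llama_token> & tokens,
                                int n_new_tokens) {
    int n_removed = 0;
    // Roll back at most the newly appended tokens; n_removed reports how many
    // tokens were actually trimmed from the end of `tokens`.
    const std::string prefix = llama_token_healing_prepare(
        ctx, llama_token_healing_type::DYNAMIC_MULTI, tokens,
        /*max_to_remove=*/n_new_tokens, &n_removed);
    if (n_removed > 0) {
        // Constrain the next sampling step to candidates that extend `prefix`.
        ctx_sampling->token_healing_prefix = prefix;
    }
}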