From 414fc13248c446331f541b7099aff26290ef26fc Mon Sep 17 00:00:00 2001
From: mare5x
Date: Sat, 29 Jun 2024 13:42:00 +0200
Subject: [PATCH] token healing : refactor to return struct

---
 common/sampling.cpp    | 51 +++++++++++++++++-------------------------
 common/sampling.h      | 20 ++++++++++-------
 examples/main/main.cpp | 25 ++++++++++-----------
 3 files changed, 44 insertions(+), 52 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index b407df45c..bdcdde057 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -49,18 +49,13 @@ static size_t get_max_token_length(const llama_context * ctx_main) {
     return len;
 }
 
-struct token_healing_info {
-    std::string prefix;
-    int n_tokens_removed;
-};
-
-token_healing_info llama_token_healing_get_prefix(
-    const llama_context * ctx_main,
-    const llama_token_healing_type th_type,
-    const std::vector<llama_token> & tokens,
-    int max_to_remove) {
+static llama_token_healing_output llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
     if (tokens.size() <= 1) {
-        return {"", 0};
+        return {};
     }
 
     const int n_ctx = tokens.size();
@@ -122,34 +117,28 @@ token_healing_info llama_token_healing_get_prefix(
 // Token healing (external)
 //
 
-std::string llama_token_healing_rollback(
-    const llama_context * ctx_main,
-    llama_token_healing_type th_type,
-    std::vector<llama_token> & tokens,
-    int max_to_remove,
-    int * n_removed) {
-    if (n_removed != nullptr) {
-        *n_removed = 0;
-    }
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove) {
     // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
     // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
-    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+    llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
 
     // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
-    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
+    if (out.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
-        return "";
+        return {};
     }
-    // Finalize outputs
-    if (n_removed != nullptr) {
-        *n_removed = info.n_tokens_removed;
-    }
-    tokens.resize(tokens.size() - info.n_tokens_removed);
-    return info.prefix;
+
+    // Finally, trim prompt tokens
+    tokens.resize(tokens.size() - out.n_tokens_removed);
+    return out;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
diff --git a/common/sampling.h b/common/sampling.h
index 4c1172985..094b40c89 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -176,13 +176,17 @@ void llama_sampling_accept(
 // Token healing
 //
 
-// Roll back `tokens` for constrained generation according to the token healing
-// strategy. Returns the prefix for constrained generation.
-std::string llama_token_healing_rollback(
-    const llama_context * ctx_main,
-    llama_token_healing_type th_type,
-    std::vector<llama_token> & tokens,
-    int max_to_remove = -1,
-    int * n_removed = nullptr);
+struct llama_token_healing_output {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+// Roll back `tokens` for constrained generation according to the token healing strategy.
+// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1);
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index b08fec7dc..6976b2697 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -295,11 +295,10 @@ int main(int argc, char ** argv) {
         sparams.token_healing_enabled = false;
         LOG("token healing: disabled due to custom suffix/conversation mode");
     }
-    std::string token_healing_prefix;
-    int token_healing_n_removed = 0;
+    llama_token_healing_output token_healing_out{};
     if (!params.interactive_first && sparams.token_healing_enabled) {
-        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                         sparams.token_healing_n_rollback);
     }
 
     // Should not run without any tokens
@@ -326,7 +325,7 @@ int main(int argc, char ** argv) {
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-        original_prompt_len = original_inp.size() - token_healing_n_removed;
+        original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -548,7 +547,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
-    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -883,7 +882,8 @@ int main(int argc, char ** argv) {
             assistant_ss << llama_token_to_piece(ctx, id, false);
         }
 
-        token_healing_n_removed = 0;
+        token_healing_out = {};
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -962,9 +962,8 @@ int main(int argc, char ** argv) {
 
                     const int max_to_remove = sparams.token_healing_n_rollback < 0 ?
                        n_new_tokens : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                                        max_to_remove, &token_healing_n_removed);
-                    n_bytes_to_skip = token_healing_prefix.size();
+                    token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                    n_bytes_to_skip = token_healing_out.prefix.size();
                 }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
@@ -976,7 +975,7 @@ int main(int argc, char ** argv) {
 
                 // reset assistant message
                 assistant_ss.str("");
-                n_remain -= line_inp.size() + token_healing_n_removed;
+                n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -988,9 +987,9 @@ int main(int argc, char ** argv) {
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
-                if (token_healing_n_removed > 0) {
+                if (token_healing_out.n_tokens_removed > 0) {
                     // Set new prefix after an interaction
-                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
                 }
             }
             is_interacting = false;
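
For review context, a minimal sketch of the call sequence after this refactor, mirroring the main.cpp changes above. The wrapper `apply_token_healing`, its parameters, and the include path are illustrative assumptions, not part of the patch:

```cpp
// Sketch: how a caller consumes the new llama_token_healing_output struct
// (cf. examples/main/main.cpp above). Assumes common/sampling.h is on the include path.
#include "sampling.h"

#include <vector>

static void apply_token_healing(llama_context * ctx,
                                llama_sampling_context * ctx_sampling,
                                llama_token_healing_type th_type,
                                int n_rollback,
                                std::vector<llama_token> & embd_inp) {
    // Roll back the prompt; the healed prefix and the number of removed tokens
    // now travel together instead of a returned string plus an int out-parameter.
    llama_token_healing_output th = llama_token_healing_rollback(ctx, th_type, embd_inp, n_rollback);

    // Constrain the first sampled tokens so they re-complete the removed prefix.
    llama_token_healing_set_prefix(ctx_sampling, th.prefix);

    // th.n_tokens_removed remains available for bookkeeping (e.g. adjusting
    // n_remain or guidance offsets, as main.cpp does above).
}
```

Compared to the old API there is no `n_removed` out-parameter to forget; a default-constructed result (empty `prefix`, `n_tokens_removed == 0`) means nothing was healed.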