From db9c018891772624d9da0a92816ce68920c67c50 Mon Sep 17 00:00:00 2001 From: mare5x Date: Sat, 29 Jun 2024 13:02:30 +0200 Subject: [PATCH] token healing : change dynamic rollback Dynamic rollback now starts checking prefixes based on the length of the longest token. --- common/sampling.cpp | 136 ++++++++++++++++++++++++++++------------- examples/main/main.cpp | 2 +- 2 files changed, 95 insertions(+), 43 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 02795a182..b407df45c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) { static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) { const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { - if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) { + std::string token = llama_token_to_piece(ctx_main, token_id); + if (startswith(token, prefix)) { return true; } } return false; } -static std::vector token_healing_find_prefix( +static std::vector token_healing_get_candidates( const llama_context * ctx_main, const std::string & prefix, const bool include_partial_prefix) { @@ -38,6 +39,85 @@ static std::vector token_healing_find_prefix( return candidates; } +static size_t get_max_token_length(const llama_context * ctx_main) { + const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + size_t len = 0; + for (llama_token token_id = 0; token_id < n_vocab; ++token_id) { + std::string token = llama_token_to_piece(ctx_main, token_id); + len = std::max(len, token.size()); + } + return len; +} + +struct token_healing_info { + std::string prefix; + int n_tokens_removed; +}; + +token_healing_info llama_token_healing_get_prefix( + const llama_context * ctx_main, + const llama_token_healing_type th_type, + const std::vector & tokens, + int max_to_remove) { + if (tokens.size() <= 1) { + return {"", 0}; + } + + const int n_ctx = tokens.size(); + max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; + max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain + + int removed = 0; + std::string prefix; + + const llama_model * model = llama_get_model(ctx_main); + auto is_special_token = [&](const llama_token token_id) { + return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id); + }; + + if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) { + // The number of bytes to roll back cannot exceed the length of the longest token. + const size_t n_longest_token = get_max_token_length(ctx_main); + size_t len = 0; + while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; + if (is_special_token(next_token_id)) { + break; + } + const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size(); + if (len + next_token_size > n_longest_token) { + break; + } + len += next_token_size; + removed += 1; + } + + while (removed > 0) { + prefix.clear(); + for (int i = n_ctx - removed; i < n_ctx; i++) { + prefix += llama_token_to_piece(ctx_main, tokens[i]); + } + if (token_healing_prefix_exists(ctx_main, prefix)) { + break; // Stop on longest valid prefix + } + removed -= 1; + } + } else { + // Roll back tokens a fixed amount and stop early if a special token is encountered. + while (removed < max_to_remove) { + const llama_token next_token_id = tokens[n_ctx - removed - 1]; + if (is_special_token(next_token_id)) { + break; + } + removed += 1; + } + for (int i = n_ctx - removed; i < n_ctx; i++) { + prefix += llama_token_to_piece(ctx_main, tokens[i]); + } + } + return {prefix, removed}; +} + // // Token healing (external) // @@ -48,56 +128,28 @@ std::string llama_token_healing_rollback( std::vector & tokens, int max_to_remove, int * n_removed) { - // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. - // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. if (n_removed != nullptr) { *n_removed = 0; } - if (tokens.size() <= 1) { - return ""; - } - const llama_model * model = llama_get_model(ctx_main); - const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI; - const int n_ctx = tokens.size(); - max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove; - max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain - int removed = 0; - std::string prefix; - // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt - // and stop early if a special token is encountered. - // NB. This doesn't handle cases where a long token is split many times, - // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically), - // then "abc" will not be returned even if "abcd" exists in the vocab. - while (removed < max_to_remove) { - const llama_token next_token_id = tokens[n_ctx - removed - 1]; - if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) { - break; // Don't roll back e.g. <|endoftext|> - } - std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix; - if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) { - break; - } - removed += 1; - prefix = new_prefix; - } - if (removed == 0) { // E.g. if the last token is a special token - return ""; - } - // If constrained decoding would give back the original prompt, there is no need to modify the context + // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back. + // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt. + token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove); + + // If constrained decoding would give back the original prompt, there is no need to modify the prompt. const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; - const std::vector candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step); - LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed); - if (removed == 1 && candidates.size() == 1) { + const std::vector candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step); + LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed); + if (info.n_tokens_removed == 1 && candidates.size() == 1) { LOG("token_healing: nothing to heal\n"); return ""; } // Finalize outputs if (n_removed != nullptr) { - *n_removed = removed; + *n_removed = info.n_tokens_removed; } - tokens.resize(n_ctx - removed); - return prefix; + tokens.resize(tokens.size() - info.n_tokens_removed); + return info.prefix; } void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) { @@ -507,7 +559,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (params.token_healing_enabled && !th_prefix.empty()) { const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || th_type == llama_token_healing_type::DYNAMIC_MULTI; - std::vector th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step); + std::vector th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step); LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); for (const llama_token token_id : th_candidates) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b3e47b36b..b08fec7dc 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -293,7 +293,7 @@ int main(int argc, char ** argv) { if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { sparams.token_healing_enabled = false; - LOG("token_healing: disabled due to custom suffix/conversation mode"); + LOG("token healing: disabled due to custom suffix/conversation mode"); } std::string token_healing_prefix; int token_healing_n_removed = 0;