From 951b6593b2ec2ebf27cf4dd03c81f8a35326c358 Mon Sep 17 00:00:00 2001
From: mare5x
Date: Fri, 3 May 2024 13:50:31 +0200
Subject: [PATCH] main : first attempt at token healing in `main`
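
Token healing rolls back the end of the prompt and constrains sampling so
that the removed suffix is regenerated as part of a whole token, avoiding
the artifacts that appear when a prompt ends in the middle of a token.
Enable it with the new `-th/--token-healing` flag, e.g. (model path
illustrative):

    ./main -m models/7B/ggml-model.gguf -p "The capital of France is Par" -th d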
---
 common/common.cpp                             |  25 ++++
 common/sampling.cpp                           | 136 +++++++++++++++++-
 common/sampling.h                             |  23 +++
 examples/main/main.cpp                        |   7 +
 .../simple-token-healing.cpp                  |  33 ++---
 5 files changed, 200 insertions(+), 24 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 243b88abf..7f1d13605 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.token_healing_enabled = true;
+        auto & th_type       = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -j SCHEMA, --json-schema SCHEMA\n");
     printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
     printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
+    printf("  -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
+    printf("                        Token healing type. (default: 0, disabled)\n");
+    printf("                        1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");
diff --git a/common/sampling.cpp b/common/sampling.cpp
index cc83600d9..5549369e8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -2,6 +2,96 @@
 #include "sampling.h"

 #include <random>
+
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
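+    // rfind with pos = 0 anchors the search at index 0, so it can only return
+    // 0 (str starts with prefix) or npos: a prefix test, not a substring search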
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+    const llama_context * ctx_main,
+    const std::string & prefix,
+    const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
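+// Roll back `tokens` so that the prompt's trailing piece can be re-completed
+// as part of a whole token, and return the removed suffix as a string (empty
+// if there is nothing to heal). The caller constrains sampling with the
+// returned prefix: e.g. a prompt ending in the token " Par" is rolled back
+// and sampling is restricted to tokens extending " Par" (" Paris", " Part",
+// ... depending on the vocab).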
+std::string llama_token_healing_prepare(
+    const llama_context * ctx_main,
+    llama_token_healing_type th_type,
+    std::vector<llama_token> & tokens,
+    int n_rollback) {
+    if (tokens.empty()) {
+        return "";
+    }
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
+    int n_removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
+    // and stop early if a special token is encountered
+    while (n_removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        n_removed += 1;
+        prefix = new_prefix;
+    }
+
+    if (n_removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
+    if (n_removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    tokens.resize(n_ctx - n_removed);
+    return prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +123,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
             grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
     }

+    result->token_healing_prefix.clear();
+
     result->prev.resize(params.n_prev);

     llama_sampling_set_rng_seed(result, params.seed);
@@ -62,6 +154,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
             grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }

+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
 }
@@ -119,7 +213,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }

 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -297,12 +391,33 @@ static llama_token_data_array llama_sampling_prepare_impl(
     cur.clear();

-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type   = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }

     llama_token_data_array cur_p = { cur.data(), cur.size(), false };

+    // TODO should we skip penalties and grammar while token healing?
+
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
@@ -361,4 +476,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e36..e2b870f00 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };

+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {

     std::vector<llama_token> penalty_prompt_tokens;
     bool                     use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int  token_healing_n_rollback = 1; // number of tokens to roll back
 } llama_sampling_params;

 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;

+    std::string token_healing_prefix;
+
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
@@ -152,3 +165,13 @@ void llama_sampling_accept(
     struct llama_context * ctx_main,
     llama_token id,
     bool apply_grammar);
+
+//
+// Token healing
+//
+
+std::string llama_token_healing_prepare(
+    const llama_context * ctx_main,
+    llama_token_healing_type th_type,
+    std::vector<llama_token> & tokens,
+    int n_rollback = 1);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 5c693657c..c9e6d2de9 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,6 +264,12 @@ int main(int argc, char ** argv) {
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
     LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

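+    // Token healing may pop tokens from the end of embd_inp; the returned
+    // prefix is handed to ctx_sampling below to constrain sampling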
+    std::string token_healing_prefix;
+    if (sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
+                                                           sparams.token_healing_n_rollback);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(model));
@@ -520,6 +526,7 @@ int main(int argc, char ** argv) {
     }

     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    ctx_sampling->token_healing_prefix = token_healing_prefix;

     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp
index 79b1693ad..05091b9c3 100644
--- a/examples/simple-token-healing/simple-token-healing.cpp
+++ b/examples/simple-token-healing/simple-token-healing.cpp
@@ -8,13 +8,6 @@

 #define TH_VERBOSE // print token healing candidates

-enum class token_healing_type : uint8_t {
-    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
-    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
-    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
-    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
-};
-
 struct token_healing_context {
     std::string prefix; // remaining prefix to generate (the input prompt's suffix)
@@ -44,8 +37,8 @@ static std::vector<llama_token> token_healing_find_prefix(
     std::vector<llama_token> candidates;
     const auto & vocab = th_ctx->vocab;
     for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
-        if (startswith(vocab[token_id], prefix)
-            || (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
+        if (startswith(vocab[token_id], prefix) ||
+            (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
             candidates.push_back((llama_token)token_id);
         }
     }
@@ -71,14 +64,14 @@ static void token_healing_free(token_healing_context * th_ctx) {
 static int token_healing_heal(
     const llama_context * ctx,
     std::vector<llama_token> & tokens_list,
-    const token_healing_type th_type,
+    const llama_token_healing_type th_type,
     token_healing_context * th_ctx,
     int n_rollback = 1) {
     if (tokens_list.empty()) {
         return 0;
     }
     const llama_model * model = llama_get_model(ctx);
-    const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
     const int n_ctx = tokens_list.size();
     const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
     int n_removed = 0;
@@ -104,7 +97,7 @@ static int token_healing_heal(
         return 0;
     }
     // If constrained decoding would give back the original prompt, there is no need to modify the context
-    const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
+    const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI;
     const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
     fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
     if (n_removed == 1 && candidates.size() == 1) {
@@ -119,9 +112,7 @@ static int token_healing_heal(
         }
     }
 #endif
-    for (int i = 0; i < n_removed; ++i) {
-        tokens_list.pop_back();
-    }
+    tokens_list.resize(n_ctx - n_removed);
     if (tokens_list.empty()) {
         // If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
         tokens_list.emplace_back(llama_token_bos(model));
@@ -146,16 +137,16 @@ int main(int argc, char ** argv) {
     }

     bool token_healing_enabled = true;
-    auto th_type = token_healing_type::DYNAMIC_MULTI;
+    auto th_type = llama_token_healing_type::DYNAMIC_MULTI;
     int th_n_rollback = 1;
     if (argc >= 4) {
         std::string value(argv[3]);
         /**/ if (value == "0" ) { token_healing_enabled = false; }
-        else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
-        else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
-        else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
         else if (value[0] == 'r' ) {
-            th_type = token_healing_type::ROLLBACK_MULTI;
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
             th_n_rollback = std::stoi(value.substr(1));
             if (th_n_rollback <= 0) {
                 token_healing_enabled = false;
@@ -281,7 +272,7 @@ int main(int argc, char ** argv) {
             // Constrain tokens based on the remaining token healing prefix
             // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
             std::vector<llama_token> th_candidates;
-            if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
+            if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
             } else {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);