diff --git a/examples/simple-token-healing/README.md b/examples/simple-token-healing/README.md index 7e5469866..533c118bd 100644 --- a/examples/simple-token-healing/README.md +++ b/examples/simple-token-healing/README.md @@ -1,10 +1,13 @@ # llama.cpp/example/simple-token-healing -This example extends [simple](../simple/README.md) with [token healing](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb). +This example extends [simple](../simple/README.md) with token healing (a.k.a. token alignment). -Without token healing: +`usage: ./simple-token-healing MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]` + +## Examples +`0`: Without token healing (same as running `./simple ...`): ```bash -./simple ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 0 ... main: n_len = 32, n_ctx = 2048, n_kv_req = 32 @@ -12,7 +15,7 @@ print('Helping the customer') ... ``` -Heal the last token (`1`): +`1`: Roll back the last token and constrain the bytes of the next token to start with the chopped-off last token [0, 2]: ```bash ./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1 ... @@ -29,9 +32,9 @@ print('Hello, World!') ... ``` -Backtrack multiple tokens until there doesn't exist a token which can cover the prompt's suffix (`n`): +`d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2]: ```bash -./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" n +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d1 ... token_healing: prefix = ' worl' (2 tokens) [ 995] ' world' @@ -46,9 +49,9 @@ print('Hello, world!') ... 
``` -Backtrack multiple tokens but don't constrain the decoding to a single token (`m`): +`d`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix but allow multiple decoding steps: ```bash -./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" m +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d ... token_healing: prefix = ' worl' (2 tokens) @@ -68,3 +71,35 @@ token_healing: prefix = ' worl' world!') ... ``` + +`r[N]`: Roll back `N` tokens and constrain the decoding to the bytes of those tokens (multiple decoding steps) [1]. +The paper [1] recommends `N=3`: +```bash +./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" r3 +... +token_healing: prefix = ', worl' (3 tokens) + +main: n_len = 32, n_ctx = 2048, n_kv_req = 32 + +print('Hello +token_healing: prefix = ', worl' + [ 11] ',' +, +token_healing: prefix = ' worl' + [ 220] ' ' + [ 266] ' w' + [ 476] ' wor' + [ 995] ' world' + [ 8688] ' worldwide' + [ 11621] ' worlds' + [ 24486] ' wo' + [ 29081] ' worldview' + [ 43249] ' worldly' + world!') +... 
+``` + +## Sources +- [0] https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb +- [1] https://arxiv.org/abs/2403.08688 +- [2] https://arxiv.org/abs/2402.01035 diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp index 48f736a0e..79b1693ad 100644 --- a/examples/simple-token-healing/simple-token-healing.cpp +++ b/examples/simple-token-healing/simple-token-healing.cpp @@ -9,19 +9,20 @@ #define TH_VERBOSE // print token healing candidates enum class token_healing_type : uint8_t { - LAST, // replace last token only - MULTI_ONCE, // replace multiple last tokens with a single token - MULTI // replace multiple last tokens with multiple decoding steps + ROLLBACK_LAST, // roll back last token with a single constrained decoding step + ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps + DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step + DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps }; struct token_healing_context { std::string prefix; // remaining prefix to generate (the input prompt's suffix) - std::vector vocab; // map token id to token piece + std::vector vocab; // map token id to token piece // TODO consider using a prefix tree }; -static inline bool startswith(const std::string & str, const std::string & prefix) { +static bool startswith(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) != std::string::npos; } @@ -67,28 +68,31 @@ static void token_healing_free(token_healing_context * th_ctx) { delete th_ctx; } -static int token_healing_start( +static int token_healing_heal( const llama_context * ctx, std::vector & tokens_list, const token_healing_type th_type, - token_healing_context * th_ctx) { + token_healing_context * th_ctx, + int n_rollback = 1) { if (tokens_list.empty()) { return 0; } const llama_model * 
model = llama_get_model(ctx); + const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI; const int n_ctx = tokens_list.size(); - const int max_to_remove = (th_type == token_healing_type::LAST) ? 1 : n_ctx; + const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx); int n_removed = 0; std::string prefix; - // Backtrack tokens until there does not exist a token that can cover the prompt + // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt + // and stop early if a special token is encountered while (n_removed < max_to_remove) { - const llama_token next_token = tokens_list[n_ctx - n_removed - 1]; - if (llama_token_get_type(model, next_token) != LLAMA_TOKEN_TYPE_NORMAL) { + const llama_token next_token_id = tokens_list[n_ctx - n_removed - 1]; + if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) { // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize) break; } - std::string new_prefix = llama_token_to_piece(ctx, next_token) + prefix; - if (!token_healing_prefix_exists(th_ctx, new_prefix)) { + std::string new_prefix = th_ctx->vocab[next_token_id] + prefix; + if (is_dynamic && !token_healing_prefix_exists(th_ctx, new_prefix)) { break; } n_removed += 1; @@ -99,14 +103,17 @@ static int token_healing_start( if (n_removed == 0) { return 0; } - const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, false); + // If constrained decoding would give back the original prompt, there is no need to modify the context + const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI; + const std::vector candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding); fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed); if (n_removed == 1 && candidates.size() == 1) { fprintf(stderr, 
"token_healing: nothing to heal\n"); return 0; } #ifdef TH_VERBOSE - if (th_type != token_healing_type::MULTI) { + if (!is_multi_decoding) { + // Other healing types get printed during decoding for (const llama_token token_id : candidates) { fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str()); } @@ -126,8 +133,8 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); - return 1 ; + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]); + return 1; } if (argc >= 2) { @@ -139,15 +146,22 @@ int main(int argc, char ** argv) { } bool token_healing_enabled = true; - auto th_type = token_healing_type::LAST; + auto th_type = token_healing_type::DYNAMIC_MULTI; + int th_n_rollback = 1; if (argc >= 4) { std::string value(argv[3]); - /**/ if (value == "0") { token_healing_enabled = false; } - else if (value == "1") { th_type = token_healing_type::LAST; } - else if (value == "n") { th_type = token_healing_type::MULTI_ONCE; } - else if (value == "m") { th_type = token_healing_type::MULTI; } - else { - printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]); + /**/ if (value == "0" ) { token_healing_enabled = false; } + else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; } + else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; } + else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; } + else if (value[0] == 'r' ) { + th_type = token_healing_type::ROLLBACK_MULTI; + th_n_rollback = std::stoi(value.substr(1)); + if (th_n_rollback <= 0) { + token_healing_enabled = false; + } + } else { + printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]); return 1; } } @@ -201,7 +215,7 @@ int main(int argc, char ** argv) { token_healing_context * th_ctx = nullptr; if (token_healing_enabled) { th_ctx = 
token_healing_init(ctx); - int th_n_tokens_removed = token_healing_start(ctx, tokens_list, th_type, th_ctx); + int th_n_tokens_removed = token_healing_heal(ctx, tokens_list, th_type, th_ctx, th_n_rollback); if (th_n_tokens_removed == 0) { token_healing_enabled = false; } @@ -267,7 +281,7 @@ int main(int argc, char ** argv) { // Constrain tokens based on the remaining token healing prefix // N.B. We could also set token constraints by setting rejected tokens' logits to -inf std::vector th_candidates; - if (th_type == token_healing_type::LAST || th_type == token_healing_type::MULTI_ONCE) { + if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false); } else { th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);