examples : more roll back options for token healing

mare5x 2024-04-30 20:04:35 +02:00
parent c77bb3203c
commit 88ef908c90
2 changed files with 83 additions and 34 deletions

examples/simple-token-healing/README.md

@@ -1,10 +1,13 @@
 # llama.cpp/example/simple-token-healing

-This example extends [simple](../simple/README.md) with [token healing](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb).
-Without token healing:
+This example extends [simple](../simple/README.md) with token healing (a.k.a. token alignment).
+
+`usage: ./simple-token-healing MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]`
+
+## Examples
+`0`: Without token healing (same as running `./simple ...`):
 ```bash
-./simple ./models/phi-2/ggml-model-q4_0.gguf "print('Hel"
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 0
 ...
 main: n_len = 32, n_ctx = 2048, n_kv_req = 32
@@ -12,7 +15,7 @@ print('Helping the customer')
 ...
 ```

-Heal the last token (`1`):
+`1`: Roll back the last token and constrain the next token to start with the bytes of the chopped-off token [0, 2]:
 ```bash
 ./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1
 ...
@@ -29,9 +32,9 @@ print('Hello, World!')
 ...
 ```

-Backtrack multiple tokens until there doesn't exist a token which can cover the prompt's suffix (`n`):
+`d1`: Roll back multiple tokens, until no token can cover the prompt's suffix, and do a single constrained decoding step [2]:
 ```bash
-./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" n
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d1
 ...
 token_healing: prefix = ' worl' (2 tokens)
  [   995] ' world'
@@ -46,9 +49,9 @@ print('Hello, world!')
 ...
 ```

-Backtrack multiple tokens but don't constrain the decoding to a single token (`m`):
+`d`: Roll back multiple tokens as in `d1`, but allow multiple constrained decoding steps:
 ```bash
-./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" m
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d
 ...
 token_healing: prefix = ' worl' (2 tokens)
@@ -68,3 +71,35 @@ token_healing: prefix = ' worl'
 world!')
 ...
 ```
+
+`r[N]`: Roll back `N` tokens and constrain the decoding to the bytes of the removed tokens (multiple decoding steps) [1].
+The paper [1] recommends `N=3`:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" r3
+...
+token_healing: prefix = ', worl' (3 tokens)
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+print('Hello
+token_healing: prefix = ', worl'
+ [    11] ','
+,
+token_healing: prefix = ' worl'
+ [   220] ' '
+ [   266] ' w'
+ [   476] ' wor'
+ [   995] ' world'
+ [  8688] ' worldwide'
+ [ 11621] ' worlds'
+ [ 24486] ' wo'
+ [ 29081] ' worldview'
+ [ 43249] ' worldly'
+world!')
+...
+```
+
+## Sources
+- [0] https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb
+- [1] https://arxiv.org/abs/2403.08688
+- [2] https://arxiv.org/abs/2402.01035
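The four healing modes differ only in how far they are willing to roll back and in how many constrained decoding steps they take afterwards. A minimal sketch of that rollback budget, with assumed names (the actual enum and helpers appear in the source diff below):

```cpp
// Sketch only: how each CLI mode maps to a rollback budget over the prompt.
// "1" and "r[N]" use a fixed budget; "d1" and "d" may roll back up to the
// whole prompt, bounded in practice by the dynamic stop condition.
#include <algorithm>

enum class mode { rollback_last, rollback_multi, dynamic_once, dynamic_multi };

static int rollback_budget(mode m, int n_rollback, int n_prompt_tokens) {
    switch (m) {
        case mode::rollback_last:  return std::min(1, n_prompt_tokens);          // "1"
        case mode::rollback_multi: return std::min(n_rollback, n_prompt_tokens); // "r[N]"
        default:                   return n_prompt_tokens;                       // "d1" / "d"
    }
}
```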

examples/simple-token-healing/simple-token-healing.cpp

@@ -9,9 +9,10 @@
 #define TH_VERBOSE // print token healing candidates

 enum class token_healing_type : uint8_t {
-    LAST,       // replace last token only
-    MULTI_ONCE, // replace multiple last tokens with a single token
-    MULTI       // replace multiple last tokens with multiple decoding steps
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed number of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
 };

 struct token_healing_context {
@@ -21,7 +22,7 @@ struct token_healing_context {
     // TODO consider using a prefix tree
 };

-static inline bool startswith(const std::string & str, const std::string & prefix) {
+static bool startswith(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) != std::string::npos;
 }
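The diff references `token_healing_prefix_exists` and `token_healing_find_prefix` but their bodies lie outside the changed hunks. As a rough sketch of what the candidate search could look like on top of `startswith` (the name, the linear scan, and the `include_partial` flag are assumptions; the TODO above suggests a prefix tree as the faster structure):

```cpp
// Assumed sketch of the candidate search used below; not part of the commit.
// include_partial mirrors the boolean passed to token_healing_find_prefix:
// multi-step decoding also accepts tokens that cover only part of the prefix.
#include <string>
#include <vector>

static std::vector<int> find_prefix_candidates(
        const std::vector<std::string> & vocab, // token id -> piece
        const std::string & prefix,
        bool include_partial) {
    std::vector<int> candidates;
    for (int id = 0; id < (int) vocab.size(); ++id) {
        const std::string & piece = vocab[id];
        if (piece.rfind(prefix, 0) == 0 ||                      // piece covers the whole prefix
            (include_partial && prefix.rfind(piece, 0) == 0)) { // piece is a prefix of the prefix
            candidates.push_back(id);
        }
    }
    return candidates;
}
```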
@@ -67,28 +68,31 @@ static void token_healing_free(token_healing_context * th_ctx) {
     delete th_ctx;
 }

-static int token_healing_start(
+static int token_healing_heal(
         const llama_context * ctx,
         std::vector<llama_token> & tokens_list,
         const token_healing_type th_type,
-        token_healing_context * th_ctx) {
+        token_healing_context * th_ctx,
+        int n_rollback = 1) {
     if (tokens_list.empty()) {
         return 0;
     }
     const llama_model * model = llama_get_model(ctx);
+    const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
     const int n_ctx = tokens_list.size();
-    const int max_to_remove = (th_type == token_healing_type::LAST) ? 1 : n_ctx;
+    const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
     int n_removed = 0;
     std::string prefix;
-    // Backtrack tokens until there does not exist a token that can cover the prompt
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt,
+    // and stop early if a special token is encountered
     while (n_removed < max_to_remove) {
-        const llama_token next_token = tokens_list[n_ctx - n_removed - 1];
-        if (llama_token_get_type(model, next_token) != LLAMA_TOKEN_TYPE_NORMAL) {
+        const llama_token next_token_id = tokens_list[n_ctx - n_removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
             // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
             break;
         }
-        std::string new_prefix = llama_token_to_piece(ctx, next_token) + prefix;
-        if (!token_healing_prefix_exists(th_ctx, new_prefix)) {
+        std::string new_prefix = th_ctx->vocab[next_token_id] + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(th_ctx, new_prefix)) {
            break;
         }
         n_removed += 1;
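To make the dynamic stop condition concrete, here is a hand trace of this loop on the README's `"print('Hello, worl"` prompt, with a toy vocabulary standing in for phi-2's (the pieces and the tokenization are assumptions chosen to mimic the README output):

```cpp
// Hand-rolled trace of the dynamic rollback loop. The tiny vocab below is an
// assumption that mimics the phi-2 pieces shown in the README; real behavior
// depends on the actual model vocabulary.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> vocab = {
        "print", "('", "Hello", ",", " wor", "l", " world", " worlds",
    };
    // Prompt "print('Hello, worl" tokenized into pieces (assumed):
    const std::vector<std::string> prompt = {"print", "('", "Hello", ",", " wor", "l"};

    std::string prefix;
    int n_removed = 0;
    while (n_removed < (int) prompt.size()) {
        const std::string candidate = prompt[prompt.size() - n_removed - 1] + prefix;
        bool covered = false; // does any vocab piece start with the candidate prefix?
        for (const std::string & piece : vocab) {
            if (piece.rfind(candidate, 0) == 0) { covered = true; break; }
        }
        if (!covered) break; // dynamic stop: nothing can cover a longer suffix
        prefix = candidate;
        n_removed += 1;
    }
    // Removes 'l' then ' wor', stops at ',' because no piece starts with ", worl".
    // Prints: prefix = ' worl' (2 tokens), matching the README's d/d1 output.
    printf("prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
    return 0;
}
```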
@@ -99,14 +103,17 @@ static int token_healing_start(
     if (n_removed == 0) {
         return 0;
     }
-    const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, false);
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
     fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
     if (n_removed == 1 && candidates.size() == 1) {
         fprintf(stderr, "token_healing: nothing to heal\n");
         return 0;
     }
 #ifdef TH_VERBOSE
-    if (th_type != token_healing_type::MULTI) {
+    if (!is_multi_decoding) {
+        // Other healing types get printed during decoding
         for (const llama_token token_id : candidates) {
             fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str());
         }
@@ -126,8 +133,8 @@ int main(int argc, char ** argv) {
     gpt_params params;

     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]);
-        return 1 ;
+        printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]);
+        return 1;
     }

     if (argc >= 2) {
@@ -139,15 +146,22 @@ int main(int argc, char ** argv) {
     }

     bool token_healing_enabled = true;
-    auto th_type = token_healing_type::LAST;
+    auto th_type = token_healing_type::DYNAMIC_MULTI;
+    int th_n_rollback = 1;
     if (argc >= 4) {
         std::string value(argv[3]);
-        /**/ if (value == "0") { token_healing_enabled = false; }
-        else if (value == "1") { th_type = token_healing_type::LAST; }
-        else if (value == "n") { th_type = token_healing_type::MULTI_ONCE; }
-        else if (value == "m") { th_type = token_healing_type::MULTI; }
-        else {
-            printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]);
+        /**/ if (value == "0" ) { token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                token_healing_enabled = false;
+            }
+        } else {
+            printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]);
             return 1;
         }
     }
@@ -201,7 +215,7 @@ int main(int argc, char ** argv) {
     token_healing_context * th_ctx = nullptr;
     if (token_healing_enabled) {
         th_ctx = token_healing_init(ctx);
-        int th_n_tokens_removed = token_healing_start(ctx, tokens_list, th_type, th_ctx);
+        int th_n_tokens_removed = token_healing_heal(ctx, tokens_list, th_type, th_ctx, th_n_rollback);
         if (th_n_tokens_removed == 0) {
             token_healing_enabled = false;
         }
@@ -267,7 +281,7 @@ int main(int argc, char ** argv) {
             // Constrain tokens based on the remaining token healing prefix
             // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
             std::vector<llama_token> th_candidates;
-            if (th_type == token_healing_type::LAST || th_type == token_healing_type::MULTI_ONCE) {
+            if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
             } else {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
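The hunk ends before the sampling call that consumes `th_candidates`. The `N.B.` above mentions an alternative: instead of sampling from the candidate list directly, mask every other logit. A minimal sketch of that variant, using plain types rather than the llama.cpp sampling API:

```cpp
// Sketch of the -inf masking alternative mentioned in the comment above:
// mask every non-candidate logit so ordinary sampling can only pick a token
// that fits the healing prefix. Hypothetical helper, not part of the diff.
#include <cmath>
#include <unordered_set>
#include <vector>

static void mask_to_candidates(std::vector<float> & logits, // one logit per token id
                               const std::vector<int> & th_candidates) {
    const std::unordered_set<int> allowed(th_candidates.begin(), th_candidates.end());
    for (int id = 0; id < (int) logits.size(); ++id) {
        if (allowed.count(id) == 0) {
            logits[id] = -INFINITY; // rejected tokens can never be sampled
        }
    }
}
```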