examples : more roll back options for token healing

parent c77bb3203c
commit 88ef908c90

2 changed files with 83 additions and 34 deletions
@@ -1,10 +1,13 @@
 # llama.cpp/example/simple-token-healing

-This example extends [simple](../simple/README.md) with [token healing](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb).
+This example extends [simple](../simple/README.md) with token healing (a.k.a. token alignment).

-Without token healing:
+`usage: ./simple-token-healing MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]`
+
+## Examples
+
+`0`: Without token healing (same as running `./simple ...`):
 ```bash
-./simple ./models/phi-2/ggml-model-q4_0.gguf "print('Hel"
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 0
 ...
 main: n_len = 32, n_ctx = 2048, n_kv_req = 32
@@ -12,7 +15,7 @@ print('Helping the customer')
 ...
 ```

-Heal the last token (`1`):
+`1`: Roll back the last token and constrain the bytes of the next token to start with the chopped-off last token [0, 2]:
 ```bash
 ./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1
 ...
@@ -29,9 +32,9 @@ print('Hello, World!')
 ...
 ```

-Backtrack multiple tokens until there doesn't exist a token which can cover the prompt's suffix (`n`):
+`d1`: Roll back multiple tokens until no token can cover the prompt's suffix, then do a single constrained decoding step [2]:
 ```bash
-./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" n
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d1
 ...
 token_healing: prefix = ' worl' (2 tokens)
 [   995] ' world'
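Aside: both `1` and `d1` constrain the next sampled token to tokens whose bytes cover the rolled-back prefix. Below is a minimal standalone sketch of that candidate filter over a toy `std::vector<std::string>` vocabulary; `find_prefix_candidates` is a hypothetical name standing in for the example's `token_healing_find_prefix`:

```cpp
// Sketch: collect all vocab ids whose piece starts with the healing prefix.
#include <cstdio>
#include <string>
#include <vector>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) != std::string::npos;
}

static std::vector<int> find_prefix_candidates( // hypothetical helper
        const std::vector<std::string> & vocab, const std::string & prefix) {
    std::vector<int> candidates;
    for (int id = 0; id < (int) vocab.size(); ++id) {
        if (startswith(vocab[id], prefix)) {
            candidates.push_back(id);
        }
    }
    return candidates;
}

int main() {
    // After rolling back "Hel", decoding is constrained to tokens starting with it.
    const std::vector<std::string> vocab = {"Hel", "Hello", "Help", "World"};
    for (int id : find_prefix_candidates(vocab, "Hel")) {
        std::printf("[%d] '%s'\n", id, vocab[id].c_str());
    }
    return 0;
}
```

The linear scan is O(vocab size) per lookup; the implementation below leaves a `// TODO consider using a prefix tree` for exactly this reason.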
@@ -46,9 +49,9 @@ print('Hello, world!')
 ...
 ```

-Backtrack multiple tokens but don't constrain the decoding to a single token (`m`):
+`d`: Roll back multiple tokens until no token can cover the prompt's suffix, but allow multiple decoding steps:
 ```bash
-./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" m
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d
 ...
 token_healing: prefix = ' worl' (2 tokens)
@@ -68,3 +71,35 @@ token_healing: prefix = ' worl'
 world!')
 ...
 ```
+
+`r[N]`: Roll back `N` tokens and constrain the decoding to the bytes of those tokens (multiple decoding steps) [1].
+The paper [1] recommends `N=3`:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" r3
+...
+token_healing: prefix = ', worl' (3 tokens)
+
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Hello
+token_healing: prefix = ', worl'
+[    11] ','
+,
+token_healing: prefix = ' worl'
+[   220] ' '
+[   266] ' w'
+[   476] ' wor'
+[   995] ' world'
+[  8688] ' worldwide'
+[ 11621] ' worlds'
+[ 24486] ' wo'
+[ 29081] ' worldview'
+[ 43249] ' worldly'
+world!')
+...
+```
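Aside: in the multi-step modes the healing prefix shrinks as tokens are emitted, as the trace above shows (`', worl'` becomes `' worl'` once `','` is sampled). A sketch of that bookkeeping, assuming each sampled piece either covers part of the prefix or finishes the healing; `advance_prefix` is illustrative, not the example's API:

```cpp
// Sketch: shrink the healing prefix after each constrained decoding step.
#include <cstdio>
#include <string>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) != std::string::npos;
}

// Remaining prefix after emitting `piece`; an empty string means healing is done.
static std::string advance_prefix(const std::string & prefix, const std::string & piece) {
    if (startswith(prefix, piece)) {
        return prefix.substr(piece.size()); // piece covered only part of the prefix
    }
    return ""; // piece covered the whole prefix (and possibly extended past it)
}

int main() {
    std::string prefix = ", worl";
    prefix = advance_prefix(prefix, ",");      // -> " worl", as in the trace above
    std::printf("'%s'\n", prefix.c_str());
    prefix = advance_prefix(prefix, " world"); // -> "", healing finished
    std::printf("'%s'\n", prefix.c_str());
    return 0;
}
```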
+
+## Sources
+
+- [0] https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb
+- [1] https://arxiv.org/abs/2403.08688
+- [2] https://arxiv.org/abs/2402.01035
@@ -9,9 +9,10 @@
 #define TH_VERBOSE // print token healing candidates

 enum class token_healing_type : uint8_t {
-    LAST,       // replace last token only
-    MULTI_ONCE, // replace multiple last tokens with a single token
-    MULTI       // replace multiple last tokens with multiple decoding steps
+    ROLLBACK_LAST,  // roll back the last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed number of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
 };

 struct token_healing_context {
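For orientation, the CLI options map onto these enumerators in `main()` below: `1` → `ROLLBACK_LAST`, `r[N]` → `ROLLBACK_MULTI`, `d1` → `DYNAMIC_ONCE`, and `d` → `DYNAMIC_MULTI` (also the default).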
@@ -21,7 +22,7 @@ struct token_healing_context {
     // TODO consider using a prefix tree
 };

-static inline bool startswith(const std::string & str, const std::string & prefix) {
+static bool startswith(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) != std::string::npos;
 }
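Note that `str.rfind(prefix, 0)` only considers matches starting at position 0, so it returns `0` when `str` begins with `prefix` and `std::string::npos` otherwise, which makes it a prefix test without any allocation.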
@@ -67,28 +68,31 @@ static void token_healing_free(token_healing_context * th_ctx) {
     delete th_ctx;
 }

-static int token_healing_start(
+static int token_healing_heal(
     const llama_context * ctx,
     std::vector<llama_token> & tokens_list,
     const token_healing_type th_type,
-    token_healing_context * th_ctx) {
+    token_healing_context * th_ctx,
+    int n_rollback = 1) {
     if (tokens_list.empty()) {
         return 0;
     }
     const llama_model * model = llama_get_model(ctx);
+    const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
     const int n_ctx = tokens_list.size();
-    const int max_to_remove = (th_type == token_healing_type::LAST) ? 1 : n_ctx;
+    const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
     int n_removed = 0;
     std::string prefix;
-    // Backtrack tokens until there does not exist a token that can cover the prompt
+    // Roll back a fixed number of tokens, or until there does not exist a token that can cover the prompt,
     // and stop early if a special token is encountered
     while (n_removed < max_to_remove) {
-        const llama_token next_token = tokens_list[n_ctx - n_removed - 1];
-        if (llama_token_get_type(model, next_token) != LLAMA_TOKEN_TYPE_NORMAL) {
+        const llama_token next_token_id = tokens_list[n_ctx - n_removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
             // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
             break;
         }
-        std::string new_prefix = llama_token_to_piece(ctx, next_token) + prefix;
-        if (!token_healing_prefix_exists(th_ctx, new_prefix)) {
+        std::string new_prefix = th_ctx->vocab[next_token_id] + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(th_ctx, new_prefix)) {
             break;
         }
         n_removed += 1;
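In short: the dynamic modes may roll back as far as the whole prompt, stopping once the accumulated suffix no longer prefixes any vocab entry, while `1` and `r[N]` roll back at most `n_rollback` tokens. Both stop early at special tokens.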
@@ -99,14 +103,17 @@ static int token_healing_start(
     if (n_removed == 0) {
         return 0;
     }
-    const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, false);
+    const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
     fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
     if (n_removed == 1 && candidates.size() == 1) {
         fprintf(stderr, "token_healing: nothing to heal\n");
         return 0;
     }
 #ifdef TH_VERBOSE
-    if (th_type != token_healing_type::MULTI) {
+    if (!is_multi_decoding) {
         // Candidates for multi-step healing are printed during decoding instead
         for (const llama_token token_id : candidates) {
             fprintf(stderr, "  [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str());
         }
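The boolean passed to `token_healing_find_prefix` evidently selects between the two candidate sets: strict candidates that fully cover the healing prefix, and, for the multi-step modes, also tokens that cover only part of it (compare the `' '`, `' w'`, `' wor'` entries in the `r3` trace above).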
@@ -126,8 +133,8 @@ int main(int argc, char ** argv) {
     gpt_params params;

     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]);
-        return 1 ;
+        printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]);
+        return 1;
     }

     if (argc >= 2) {
@@ -139,15 +146,22 @@ int main(int argc, char ** argv) {
     }

     bool token_healing_enabled = true;
-    auto th_type = token_healing_type::LAST;
+    auto th_type = token_healing_type::DYNAMIC_MULTI;
+    int th_n_rollback = 1;
     if (argc >= 4) {
         std::string value(argv[3]);
-        /**/ if (value == "0") { token_healing_enabled = false; }
-        else if (value == "1") { th_type = token_healing_type::LAST; }
-        else if (value == "n") { th_type = token_healing_type::MULTI_ONCE; }
-        else if (value == "m") { th_type = token_healing_type::MULTI; }
-        else {
-            printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|n|m]\n" , argv[0]);
+        /**/ if (value == "0" ) { token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                token_healing_enabled = false;
+            }
+        } else {
+            printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n" , argv[0]);
             return 1;
         }
     }
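So `r3` selects `ROLLBACK_MULTI` with `th_n_rollback = 3`, and a non-positive count such as `r0` simply disables healing. Note that `std::stoi` throws `std::invalid_argument` on malformed input such as `rx`, so the argument is assumed to be well formed.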
@@ -201,7 +215,7 @@ int main(int argc, char ** argv) {
     token_healing_context * th_ctx = nullptr;
     if (token_healing_enabled) {
         th_ctx = token_healing_init(ctx);
-        int th_n_tokens_removed = token_healing_start(ctx, tokens_list, th_type, th_ctx);
+        int th_n_tokens_removed = token_healing_heal(ctx, tokens_list, th_type, th_ctx, th_n_rollback);
         if (th_n_tokens_removed == 0) {
             token_healing_enabled = false;
         }
@@ -267,7 +281,7 @@ int main(int argc, char ** argv) {
             // Constrain tokens based on the remaining token healing prefix
             // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
             std::vector<llama_token> th_candidates;
-            if (th_type == token_healing_type::LAST || th_type == token_healing_type::MULTI_ONCE) {
+            if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
             } else {
                 th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
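Aside: the N.B. above mentions an equivalent formulation: instead of sampling from an explicit `th_candidates` list, mask every non-candidate logit. A minimal sketch under the assumption of a plain `std::vector<float>` of logits (llama.cpp's actual sampling uses `llama_token_data` arrays; `mask_non_candidates` is a hypothetical helper):

```cpp
// Sketch: make all tokens outside the healing candidates unsampleable
// by setting their logits to -infinity before softmax/sampling.
#include <cmath>
#include <vector>

static void mask_non_candidates(std::vector<float> & logits,
                                const std::vector<int> & candidates) {
    std::vector<bool> allowed(logits.size(), false);
    for (int id : candidates) {
        allowed[id] = true;
    }
    for (size_t id = 0; id < logits.size(); ++id) {
        if (!allowed[id]) {
            logits[id] = -INFINITY; // rejected tokens get zero probability
        }
    }
}
```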