main : add token healing

mare5x 2024-06-27 16:08:24 +02:00
parent 272e3bd95e
commit 13885c747e
5 changed files with 249 additions and 6 deletions

@ -1093,6 +1093,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "-th" || arg == "--token-healing") {
CHECK_ARG
sparams.token_healing_enabled = true;
auto & th_type = sparams.token_healing_type;
auto & th_n_rollback = sparams.token_healing_n_rollback;
std::string value(argv[i]);
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
sparams.token_healing_enabled = false;
}
} else { invalid_param = true; }
return true;
}
if (arg == "--override-kv") {
CHECK_ARG
if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
@ -1501,6 +1520,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted:\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
"Token healing type. (default: 0, disabled)\n"
"1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });

@ -2,6 +2,112 @@
#include "sampling.h"
#include <random>
//
// Token healing (internal)
//
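// Check whether `str` begins with `prefix` (rfind anchored at position 0 can only match at index 0)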
static bool startswith(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) != std::string::npos;
}
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
return true;
}
}
return false;
}
static std::vector<llama_token> token_healing_find_prefix(
const llama_context * ctx_main,
const std::string & prefix,
const bool include_partial_prefix) {
// Example: prefix=" world" -> " world", " worldwide", ...
// If `include_partial_prefix`, include also: " w", " wo", ...
std::vector<llama_token> candidates;
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix) ||
(include_partial_prefix && startswith(prefix, token))) {
candidates.push_back(token_id);
}
}
return candidates;
}
//
// Token healing (external)
//
std::string llama_token_healing_rollback(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove,
int * n_removed) {
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
if (n_removed != nullptr) {
*n_removed = 0;
}
if (tokens.size() <= 1) {
return "";
}
const llama_model * model = llama_get_model(ctx_main);
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
const int n_ctx = tokens.size();
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
int removed = 0;
std::string prefix;
// Roll back tokens by a fixed amount, or (for the dynamic types) until no vocab token can cover the rolled-back prompt suffix,
// and stop early if a special token is encountered.
// NB. This doesn't handle cases where a long token is split many times,
// e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
// then "abc" will not be returned even if "abcd" exists in the vocab.
while (removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - removed - 1];
if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
break; // Don't roll back e.g. <|endoftext|>
}
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
break;
}
removed += 1;
prefix = new_prefix;
}
if (removed == 0) { // E.g. if the last token is a special token
return "";
}
// If constrained decoding would give back the original prompt, there is no need to modify the context
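// (e.g. only one token was rolled back and the only vocab token starting with its piece is that piece itself)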
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
if (removed == 1 && candidates.size() == 1) {
LOG("token_healing: nothing to heal\n");
return "";
}
// Finalize outputs
if (n_removed != nullptr) {
*n_removed = removed;
}
tokens.resize(n_ctx - removed);
return prefix;
}
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
ctx_sampling->token_healing_prefix = prefix;
}
//
// Sampling
//
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@ -72,6 +178,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->grammar = grammar;
}
ctx->token_healing_prefix.clear();
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_valid = 0;
@ -130,7 +238,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
}
std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
std::string result = "(Token healing) -> CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@ -393,8 +501,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
// Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_type = params.token_healing_type;
const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing_enabled && !th_prefix.empty()) {
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
for (const llama_token token_id : th_candidates) {
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
}
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
for (const llama_token token_id : th_candidates) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
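The N.B. above mentions an alternative: instead of shrinking `cur` to the candidate set, the rejected tokens' logits could be pushed to -inf. A minimal sketch of that variant (the helper name is made up and it is not part of this commit):

// Illustrative alternative (not in this commit): keep all n_vocab entries in `cur`
// and mask every token outside `th_candidates` to -inf instead of shrinking the list.
// Assumes <cmath> for INFINITY; `llama_token_data`/`llama_token` are the llama.cpp types used above.
static void token_healing_mask_logits(
        std::vector<llama_token_data> & cur,               // one entry per vocab id, as built above
        const std::vector<llama_token> & th_candidates) {  // token ids allowed by the healing prefix
    std::vector<char> allowed(cur.size(), 0);
    for (const llama_token token_id : th_candidates) {
        allowed[token_id] = 1;
    }
    for (size_t token_id = 0; token_id < cur.size(); ++token_id) {
        if (!allowed[token_id]) {
            cur[token_id].logit = -INFINITY; // rejected: cannot continue the healing prefix
        }
    }
}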
@ -457,4 +584,19 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
if (new_token_piece.size() < th_prefix.size()) {
// Shift prefix constraint (for multi step token healing)
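// e.g. th_prefix " worldwide" with sampled piece " world" leaves "wide" still to be matched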
th_prefix = th_prefix.substr(new_token_piece.size());
} else {
// Prefix has been generated => no more constrained generation
th_prefix.clear();
LOG("token_healing: done\n");
}
}
}
}

@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};
enum class llama_token_healing_type : uint8_t {
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};
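// Selected on the command line via -th/--token-healing:
//   "1" -> ROLLBACK_LAST, "r{N}" -> ROLLBACK_MULTI, "d1" -> DYNAMIC_ONCE, "d" -> DYNAMIC_MULTI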
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
bool token_healing_enabled = false;
int token_healing_n_rollback = -1; // number of tokens to roll back
} llama_sampling_params;
// general sampler context
@ -78,6 +89,8 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;
std::string token_healing_prefix; // remaining prefix to constrain sampling
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
@ -158,3 +171,18 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);
//
// Token healing
//
// Roll back `tokens` for constrained generation according to the token healing
// strategy. Returns the prefix for constrained generation.
std::string llama_token_healing_rollback(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove = -1,
int * n_removed = nullptr);
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
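For orientation, a minimal sketch of the intended call order, mirroring the examples/main changes further below (`ctx`, `ctx_sampling`, `params` and `sparams` are assumed to exist in the caller):

// Sketch of typical usage (surrounding variables are assumed from the caller, as in examples/main).
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
int n_removed = 0;
const std::string prefix = llama_token_healing_rollback(
        ctx, sparams.token_healing_type, embd_inp,
        sparams.token_healing_n_rollback, &n_removed);  // trims the rolled-back tokens off embd_inp
llama_token_healing_set_prefix(ctx_sampling, prefix);   // the first sampled token(s) must rebuild `prefix`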

@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
Example usage: `--logit-bias 29905-inf`
### Token healing
- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
- `-th 1`: Roll back the last token and constrain the next token to start with the bytes of the chopped-off token [0, 2].
- `-th d1`: Roll back multiple tokens, until no vocab token can cover the prompt's suffix, and do a single constrained decoding step [2].
- `-th d`: Like `d1`, but allow multiple constrained decoding steps until the removed suffix is regenerated.
- `-th r{N}`: Like `d`, but roll back `N` tokens (`-th r3` is recommended) [1].
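Example usage (the binary name and model file here are placeholders): `./llama-cli -m model.gguf -p "The quick brown f" -th d` lets the model reconsider the trailing `" f"` piece and regenerate the word instead of being locked to that token boundary.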
Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
### RNG Seed
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).

@ -291,6 +291,17 @@ int main(int argc, char ** argv) {
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
sparams.token_healing_enabled = false;
LOG("token_healing: disabled due to custom suffix/conversation mode\n");
}
std::string token_healing_prefix;
int token_healing_n_removed = 0;
if (!params.interactive_first && sparams.token_healing_enabled) {
token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
sparams.token_healing_n_rollback, &token_healing_n_removed);
}
// Should not run without any tokens
if (embd_inp.empty()) {
if (add_bos) {
@ -315,7 +326,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
original_prompt_len = original_inp.size() - token_healing_n_removed;
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
@ -510,6 +521,7 @@ int main(int argc, char ** argv) {
int n_consumed = 0;
int n_session_consumed = 0;
int n_past_guidance = 0;
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@ -536,6 +548,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
@ -770,7 +783,15 @@ int main(int argc, char ** argv) {
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
// Console/Stream Output
fprintf(stdout, "%s", token_str.c_str());
// Suppress printing while generating token healing prefix
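// (n_bytes_to_skip is only set after interactive input, where the healed bytes are already visible as part of what the user typed)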
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
n_bytes_to_skip = 0;
} else if (n_bytes_to_skip > 0) {
n_bytes_to_skip -= token_str.size();
} else {
fprintf(stdout, "%s", token_str.c_str());
}
// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
@ -862,6 +883,7 @@ int main(int argc, char ** argv) {
assistant_ss << llama_token_to_piece(ctx, id, false);
}
token_healing_n_removed = 0;
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");
@ -934,6 +956,17 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
if (sparams.token_healing_enabled) {
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
const int n_new_tokens = embd_inp.size() - original_size;
const int max_to_remove = sparams.token_healing_n_rollback < 0
? n_new_tokens
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
max_to_remove, &token_healing_n_removed);
n_bytes_to_skip = token_healing_prefix.size();
}
for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
@ -943,7 +976,7 @@ int main(int argc, char ** argv) {
// reset assistant message
assistant_ss.str("");
n_remain -= line_inp.size();
n_remain -= line_inp.size() + token_healing_n_removed;
LOG("n_remain: %d\n", n_remain);
} else {
LOG("empty line, passing control back\n");
@ -955,6 +988,10 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(ctx_sampling);
if (token_healing_n_removed > 0) {
// Set new prefix after an interaction
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
}
}
is_interacting = false;
}