token healing : change dynamic rollback

Dynamic rollback now first rolls back as many tokens as fit within the length of the longest token in the vocabulary, then checks prefixes from longest to shortest and stops at the longest prefix that an existing token can still complete.
mare5x 2024-06-29 13:02:30 +02:00
parent 13885c747e
commit db9c018891
2 changed files with 95 additions and 43 deletions
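As a concrete illustration of the new behaviour, here is a minimal self-contained sketch of the two-phase dynamic rollback over a toy string vocabulary. All names here (dynamic_rollback, vocab, prompt) are illustrative stand-ins for the llama.cpp API, not code from this patch:

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Illustrative two-phase dynamic rollback (not from the patch).
// Phase 1: roll back whole prompt tokens while their combined byte length
// stays within the longest vocab token -- no single token could ever cover
// a longer prefix, so rolling back further is useless.
// Phase 2: shrink the rollback until some vocab token starts with the
// prefix, i.e. keep the longest healable prefix.
static std::string dynamic_rollback(const std::vector<std::string> & vocab,
                                    std::vector<std::string> & prompt) {
    size_t longest = 0;
    for (const std::string & piece : vocab) {
        longest = std::max(longest, piece.size());
    }
    size_t removed = 0;
    size_t len = 0;
    while (removed + 1 < prompt.size()) { // at least 1 token must remain
        const std::string & next = prompt[prompt.size() - removed - 1];
        if (len + next.size() > longest) {
            break;
        }
        len += next.size();
        removed += 1;
    }
    while (removed > 0) {
        std::string prefix;
        for (size_t i = prompt.size() - removed; i < prompt.size(); i++) {
            prefix += prompt[i];
        }
        const bool exists = std::any_of(vocab.begin(), vocab.end(),
            [&](const std::string & t) { return t.rfind(prefix, 0) == 0; });
        if (exists) {
            prompt.resize(prompt.size() - removed);
            return prefix; // longest prefix some vocab token can complete
        }
        removed -= 1;
    }
    return ""; // nothing to heal
}

For example, with vocab = {"ab", "abcd"} and prompt = {"x", "a", "b"}, the sketch rolls back "a" and "b", returns the prefix "ab", and sampling would then be constrained to tokens starting with "ab". The previous implementation extended the prefix one token at a time and stopped at the first prefix with no vocab match, which could miss longer healable prefixes.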


@@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) {
 static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
     const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
     for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
-        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix)) {
             return true;
         }
     }
     return false;
 }
 
-static std::vector<llama_token> token_healing_find_prefix(
+static std::vector<llama_token> token_healing_get_candidates(
         const llama_context * ctx_main,
         const std::string & prefix,
         const bool include_partial_prefix) {
@@ -38,6 +39,85 @@ static std::vector<llama_token> token_healing_find_prefix(
     return candidates;
 }
 
+static size_t get_max_token_length(const llama_context * ctx_main) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    size_t len = 0;
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        len = std::max(len, token.size());
+    }
+    return len;
+}
+
+struct token_healing_info {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+token_healing_info llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
+    if (tokens.size() <= 1) {
+        return {"", 0};
+    }
+
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+
+    int removed = 0;
+    std::string prefix;
+
+    const llama_model * model = llama_get_model(ctx_main);
+    auto is_special_token = [&](const llama_token token_id) {
+        return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id);
+    };
+
+    if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
+        // The number of bytes to roll back cannot exceed the length of the longest token.
+        const size_t n_longest_token = get_max_token_length(ctx_main);
+        size_t len = 0;
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
+            if (len + next_token_size > n_longest_token) {
+                break;
+            }
+            len += next_token_size;
+            removed += 1;
+        }
+
+        while (removed > 0) {
+            prefix.clear();
+            for (int i = n_ctx - removed; i < n_ctx; i++) {
+                prefix += llama_token_to_piece(ctx_main, tokens[i]);
+            }
+            if (token_healing_prefix_exists(ctx_main, prefix)) {
+                break; // Stop on longest valid prefix
+            }
+            removed -= 1;
+        }
+    } else {
+        // Roll back tokens a fixed amount and stop early if a special token is encountered.
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            removed += 1;
+        }
+        for (int i = n_ctx - removed; i < n_ctx; i++) {
+            prefix += llama_token_to_piece(ctx_main, tokens[i]);
+        }
+    }
+    return {prefix, removed};
+}
+
 //
 // Token healing (external)
 //
@@ -48,56 +128,28 @@ std::string llama_token_healing_rollback(
         std::vector<llama_token> & tokens,
         int max_to_remove,
         int * n_removed) {
-    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
-    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
     if (n_removed != nullptr) {
         *n_removed = 0;
     }
-    if (tokens.size() <= 1) {
-        return "";
-    }
-    const llama_model * model = llama_get_model(ctx_main);
-    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const int n_ctx = tokens.size();
-    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
-    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
-    int removed = 0;
-    std::string prefix;
-    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
-    // and stop early if a special token is encountered.
-    // NB. This doesn't handle cases where a long token is split many times,
-    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
-    // then "abc" will not be returned even if "abcd" exists in the vocab.
-    while (removed < max_to_remove) {
-        const llama_token next_token_id = tokens[n_ctx - removed - 1];
-        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
-            break; // Don't roll back e.g. <|endoftext|>
-        }
-        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
-        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
-            break;
-        }
-        removed += 1;
-        prefix = new_prefix;
-    }
-    if (removed == 0) { // E.g. if the last token is a special token
-        return "";
-    }
-    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+
+    // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
-    if (removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
+    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
         return "";
     }
     // Finalize outputs
     if (n_removed != nullptr) {
-        *n_removed = removed;
+        *n_removed = info.n_tokens_removed;
     }
-    tokens.resize(n_ctx - removed);
-    return prefix;
+    tokens.resize(tokens.size() - info.n_tokens_removed);
+    return info.prefix;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
@@ -507,7 +559,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (params.token_healing_enabled && !th_prefix.empty()) {
         const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                    th_type == llama_token_healing_type::DYNAMIC_MULTI;
-        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+        std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
 
         LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
         for (const llama_token token_id : th_candidates) {


@@ -293,7 +293,7 @@ int main(int argc, char ** argv) {
     if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
         sparams.token_healing_enabled = false;
-        LOG("token_healing: disabled due to custom suffix/conversation mode");
+        LOG("token healing: disabled due to custom suffix/conversation mode");
     }
 
     std::string token_healing_prefix;
     int token_healing_n_removed = 0;
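For reference, the rolled-back prefix is consumed roughly like this at the call site in the example program; a hedged sketch pieced together from the signatures in this diff, not a verbatim excerpt (the sparams field names and the ctx and embd_inp variables are assumptions):

// Hedged sketch of a call site (field names on sparams are assumptions):
std::string token_healing_prefix;
int token_healing_n_removed = 0;
if (sparams.token_healing_enabled) {
    // Remove trailing prompt tokens and remember the text they covered;
    // sampling is then constrained to tokens that extend this prefix.
    token_healing_prefix = llama_token_healing_rollback(
            ctx, sparams.token_healing_type, embd_inp,
            sparams.token_healing_n_rollback, &token_healing_n_removed);
}

Presumably the returned prefix is then handed to llama_token_healing_set_prefix (shown in the first file) so that llama_sampling_prepare_impl can restrict sampling to the candidate tokens that extend it.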