token healing : change dynamic rollback
Dynamic rollback now starts checking prefixes based on the length of the longest token.
This commit is contained in:
parent
13885c747e
commit
db9c018891
2 changed files with 95 additions and 43 deletions
|
@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) {
|
|||
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
if (startswith(token, prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::vector<llama_token> token_healing_find_prefix(
|
||||
static std::vector<llama_token> token_healing_get_candidates(
|
||||
const llama_context * ctx_main,
|
||||
const std::string & prefix,
|
||||
const bool include_partial_prefix) {
|
||||
|
@ -38,6 +39,85 @@ static std::vector<llama_token> token_healing_find_prefix(
|
|||
return candidates;
|
||||
}
|
||||
|
||||
static size_t get_max_token_length(const llama_context * ctx_main) {
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
size_t len = 0;
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
len = std::max(len, token.size());
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
struct token_healing_info {
|
||||
std::string prefix;
|
||||
int n_tokens_removed;
|
||||
};
|
||||
|
||||
token_healing_info llama_token_healing_get_prefix(
|
||||
const llama_context * ctx_main,
|
||||
const llama_token_healing_type th_type,
|
||||
const std::vector<llama_token> & tokens,
|
||||
int max_to_remove) {
|
||||
if (tokens.size() <= 1) {
|
||||
return {"", 0};
|
||||
}
|
||||
|
||||
const int n_ctx = tokens.size();
|
||||
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
||||
|
||||
int removed = 0;
|
||||
std::string prefix;
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_main);
|
||||
auto is_special_token = [&](const llama_token token_id) {
|
||||
return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id);
|
||||
};
|
||||
|
||||
if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
|
||||
// The number of bytes to roll back cannot exceed the length of the longest token.
|
||||
const size_t n_longest_token = get_max_token_length(ctx_main);
|
||||
size_t len = 0;
|
||||
while (removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||
if (is_special_token(next_token_id)) {
|
||||
break;
|
||||
}
|
||||
const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
|
||||
if (len + next_token_size > n_longest_token) {
|
||||
break;
|
||||
}
|
||||
len += next_token_size;
|
||||
removed += 1;
|
||||
}
|
||||
|
||||
while (removed > 0) {
|
||||
prefix.clear();
|
||||
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||
}
|
||||
if (token_healing_prefix_exists(ctx_main, prefix)) {
|
||||
break; // Stop on longest valid prefix
|
||||
}
|
||||
removed -= 1;
|
||||
}
|
||||
} else {
|
||||
// Roll back tokens a fixed amount and stop early if a special token is encountered.
|
||||
while (removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||
if (is_special_token(next_token_id)) {
|
||||
break;
|
||||
}
|
||||
removed += 1;
|
||||
}
|
||||
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||
}
|
||||
}
|
||||
return {prefix, removed};
|
||||
}
|
||||
|
||||
//
|
||||
// Token healing (external)
|
||||
//
|
||||
|
@ -48,56 +128,28 @@ std::string llama_token_healing_rollback(
|
|||
std::vector<llama_token> & tokens,
|
||||
int max_to_remove,
|
||||
int * n_removed) {
|
||||
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
||||
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
||||
if (n_removed != nullptr) {
|
||||
*n_removed = 0;
|
||||
}
|
||||
if (tokens.size() <= 1) {
|
||||
return "";
|
||||
}
|
||||
const llama_model * model = llama_get_model(ctx_main);
|
||||
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const int n_ctx = tokens.size();
|
||||
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
||||
int removed = 0;
|
||||
std::string prefix;
|
||||
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
||||
// and stop early if a special token is encountered.
|
||||
// NB. This doesn't handle cases where a long token is split many times,
|
||||
// e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
|
||||
// then "abc" will not be returned even if "abcd" exists in the vocab.
|
||||
while (removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||
if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
|
||||
break; // Don't roll back e.g. <|endoftext|>
|
||||
}
|
||||
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
|
||||
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
||||
break;
|
||||
}
|
||||
removed += 1;
|
||||
prefix = new_prefix;
|
||||
}
|
||||
if (removed == 0) { // E.g. if the last token is a special token
|
||||
return "";
|
||||
}
|
||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
||||
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
||||
token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
|
||||
|
||||
// If constrained decoding would give back the original prompt, there is no need to modify the prompt.
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
|
||||
if (removed == 1 && candidates.size() == 1) {
|
||||
const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
|
||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
|
||||
if (info.n_tokens_removed == 1 && candidates.size() == 1) {
|
||||
LOG("token_healing: nothing to heal\n");
|
||||
return "";
|
||||
}
|
||||
// Finalize outputs
|
||||
if (n_removed != nullptr) {
|
||||
*n_removed = removed;
|
||||
*n_removed = info.n_tokens_removed;
|
||||
}
|
||||
tokens.resize(n_ctx - removed);
|
||||
return prefix;
|
||||
tokens.resize(tokens.size() - info.n_tokens_removed);
|
||||
return info.prefix;
|
||||
}
|
||||
|
||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
||||
|
@ -507,7 +559,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
if (params.token_healing_enabled && !th_prefix.empty()) {
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
|
||||
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||
|
||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||
for (const llama_token token_id : th_candidates) {
|
||||
|
|
|
@ -293,7 +293,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||
sparams.token_healing_enabled = false;
|
||||
LOG("token_healing: disabled due to custom suffix/conversation mode");
|
||||
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||
}
|
||||
std::string token_healing_prefix;
|
||||
int token_healing_n_removed = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue