token healing : change dynamic rollback
Dynamic rollback now starts checking prefixes based on the length of the longest token.
This commit is contained in:
parent
13885c747e
commit
db9c018891
2 changed files with 95 additions and 43 deletions
|
@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) {
|
||||||
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
if (startswith(token, prefix)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llama_token> token_healing_find_prefix(
|
static std::vector<llama_token> token_healing_get_candidates(
|
||||||
const llama_context * ctx_main,
|
const llama_context * ctx_main,
|
||||||
const std::string & prefix,
|
const std::string & prefix,
|
||||||
const bool include_partial_prefix) {
|
const bool include_partial_prefix) {
|
||||||
|
@ -38,6 +39,85 @@ static std::vector<llama_token> token_healing_find_prefix(
|
||||||
return candidates;
|
return candidates;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t get_max_token_length(const llama_context * ctx_main) {
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
size_t len = 0;
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
len = std::max(len, token.size());
|
||||||
|
}
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct token_healing_info {
|
||||||
|
std::string prefix;
|
||||||
|
int n_tokens_removed;
|
||||||
|
};
|
||||||
|
|
||||||
|
token_healing_info llama_token_healing_get_prefix(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
const llama_token_healing_type th_type,
|
||||||
|
const std::vector<llama_token> & tokens,
|
||||||
|
int max_to_remove) {
|
||||||
|
if (tokens.size() <= 1) {
|
||||||
|
return {"", 0};
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_ctx = tokens.size();
|
||||||
|
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||||
|
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
||||||
|
|
||||||
|
int removed = 0;
|
||||||
|
std::string prefix;
|
||||||
|
|
||||||
|
const llama_model * model = llama_get_model(ctx_main);
|
||||||
|
auto is_special_token = [&](const llama_token token_id) {
|
||||||
|
return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
|
||||||
|
// The number of bytes to roll back cannot exceed the length of the longest token.
|
||||||
|
const size_t n_longest_token = get_max_token_length(ctx_main);
|
||||||
|
size_t len = 0;
|
||||||
|
while (removed < max_to_remove) {
|
||||||
|
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||||
|
if (is_special_token(next_token_id)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
|
||||||
|
if (len + next_token_size > n_longest_token) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
len += next_token_size;
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (removed > 0) {
|
||||||
|
prefix.clear();
|
||||||
|
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||||
|
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||||
|
}
|
||||||
|
if (token_healing_prefix_exists(ctx_main, prefix)) {
|
||||||
|
break; // Stop on longest valid prefix
|
||||||
|
}
|
||||||
|
removed -= 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Roll back tokens a fixed amount and stop early if a special token is encountered.
|
||||||
|
while (removed < max_to_remove) {
|
||||||
|
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||||
|
if (is_special_token(next_token_id)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||||
|
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {prefix, removed};
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Token healing (external)
|
// Token healing (external)
|
||||||
//
|
//
|
||||||
|
@ -48,56 +128,28 @@ std::string llama_token_healing_rollback(
|
||||||
std::vector<llama_token> & tokens,
|
std::vector<llama_token> & tokens,
|
||||||
int max_to_remove,
|
int max_to_remove,
|
||||||
int * n_removed) {
|
int * n_removed) {
|
||||||
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
|
||||||
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
|
||||||
if (n_removed != nullptr) {
|
if (n_removed != nullptr) {
|
||||||
*n_removed = 0;
|
*n_removed = 0;
|
||||||
}
|
}
|
||||||
if (tokens.size() <= 1) {
|
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
||||||
return "";
|
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
||||||
}
|
token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
|
||||||
const llama_model * model = llama_get_model(ctx_main);
|
|
||||||
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
// If constrained decoding would give back the original prompt, there is no need to modify the prompt.
|
||||||
const int n_ctx = tokens.size();
|
|
||||||
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
|
||||||
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
|
||||||
int removed = 0;
|
|
||||||
std::string prefix;
|
|
||||||
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
|
||||||
// and stop early if a special token is encountered.
|
|
||||||
// NB. This doesn't handle cases where a long token is split many times,
|
|
||||||
// e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
|
|
||||||
// then "abc" will not be returned even if "abcd" exists in the vocab.
|
|
||||||
while (removed < max_to_remove) {
|
|
||||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
|
||||||
if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
|
|
||||||
break; // Don't roll back e.g. <|endoftext|>
|
|
||||||
}
|
|
||||||
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
|
|
||||||
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
removed += 1;
|
|
||||||
prefix = new_prefix;
|
|
||||||
}
|
|
||||||
if (removed == 0) { // E.g. if the last token is a special token
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
|
||||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
|
||||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
|
LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
|
||||||
if (removed == 1 && candidates.size() == 1) {
|
if (info.n_tokens_removed == 1 && candidates.size() == 1) {
|
||||||
LOG("token_healing: nothing to heal\n");
|
LOG("token_healing: nothing to heal\n");
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
// Finalize outputs
|
// Finalize outputs
|
||||||
if (n_removed != nullptr) {
|
if (n_removed != nullptr) {
|
||||||
*n_removed = removed;
|
*n_removed = info.n_tokens_removed;
|
||||||
}
|
}
|
||||||
tokens.resize(n_ctx - removed);
|
tokens.resize(tokens.size() - info.n_tokens_removed);
|
||||||
return prefix;
|
return info.prefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
||||||
|
@ -507,7 +559,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
if (params.token_healing_enabled && !th_prefix.empty()) {
|
if (params.token_healing_enabled && !th_prefix.empty()) {
|
||||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
|
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||||
|
|
||||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||||
for (const llama_token token_id : th_candidates) {
|
for (const llama_token token_id : th_candidates) {
|
||||||
|
|
|
@ -293,7 +293,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
|
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||||
sparams.token_healing_enabled = false;
|
sparams.token_healing_enabled = false;
|
||||||
LOG("token_healing: disabled due to custom suffix/conversation mode");
|
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||||
}
|
}
|
||||||
std::string token_healing_prefix;
|
std::string token_healing_prefix;
|
||||||
int token_healing_n_removed = 0;
|
int token_healing_n_removed = 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue