main : first attempt at token healing in main

mare5x 2024-05-03 13:50:31 +02:00
parent 88ef908c90
commit 951b6593b2
5 changed files with 200 additions and 24 deletions


@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "-th" || arg == "--token-healing") {
if (++i >= argc) {
invalid_param = true;
return true;
}
sparams.token_healing_enabled = true;
auto & th_type = sparams.token_healing_type;
auto & th_n_rollback = sparams.token_healing_n_rollback;
std::string value(argv[i]);
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
sparams.token_healing_enabled = false;
}
} else { invalid_param = true; }
return true;
}
if (arg == "--override-kv") { if (arg == "--override-kv") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -j SCHEMA, --json-schema SCHEMA\n"); printf(" -j SCHEMA, --json-schema SCHEMA\n");
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n"); printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n"); printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
printf(" Token healing type. (default: 0, disabled)\n");
printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
printf(" --cfg-negative-prompt PROMPT\n"); printf(" --cfg-negative-prompt PROMPT\n");
printf(" negative prompt to use for guidance. (default: empty)\n"); printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n"); printf(" --cfg-negative-prompt-file FNAME\n");


@ -2,6 +2,96 @@
#include "sampling.h" #include "sampling.h"
#include <random> #include <random>
//
// Token healing (internal)
//
static bool startswith(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) != std::string::npos;
}
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
return true;
}
}
return false;
}
static std::vector<llama_token> token_healing_find_prefix(
const llama_context * ctx_main,
const std::string & prefix,
const bool include_partial_prefix) {
// Example: prefix=" world" -> " world", " worldwide", ...
// If `include_partial_prefix`, include also: " w", " wo", ...
std::vector<llama_token> candidates;
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix) ||
(include_partial_prefix && startswith(prefix, token))) {
candidates.push_back(token_id);
}
}
return candidates;
}
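As an aside, the strict and partial matching modes above can be checked on plain strings without a model; the standalone sketch below uses a made-up toy vocabulary (illustrative only, not part of this commit) and mirrors the startswith logic:

// toy_prefix_match.cpp - illustrative only; mimics the matching above on a toy vocabulary
#include <cstdio>
#include <string>
#include <vector>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) != std::string::npos;
}

int main() {
    const std::vector<std::string> vocab = { " w", " wo", " world", " worldwide", " word", "hello" };
    const std::string prefix = " world";
    for (const std::string & token : vocab) {
        const bool strict  = startswith(token, prefix);            // token continues the healing prefix
        const bool partial = strict || startswith(prefix, token);  // or is itself a prefix of it (multi-step)
        printf("'%s' -> strict=%d partial=%d\n", token.c_str(), (int) strict, (int) partial);
    }
    // strict  candidates: " world", " worldwide"
    // partial candidates: " w", " wo", " world", " worldwide"
    return 0;
}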
//
// Token healing (external)
//
std::string llama_token_healing_prepare(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int n_rollback) {
if (tokens.empty()) {
return "";
}
const llama_model * model = llama_get_model(ctx_main);
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
const int n_ctx = tokens.size();
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
int n_removed = 0;
std::string prefix;
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
// and stop early if a special token is encountered
while (n_removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
break;
}
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
break;
}
n_removed += 1;
prefix = new_prefix;
}
if (n_removed == 0) { // E.g. if the last token is a special token
return "";
}
// If constrained decoding would give back the original prompt, there is no need to modify the context
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
if (n_removed == 1 && candidates.size() == 1) {
LOG("token_healing: nothing to heal\n");
return "";
}
tokens.resize(n_ctx - n_removed);
return prefix;
}
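The dynamic rollback loop above can likewise be simulated on plain strings; the sketch below uses a hypothetical toy vocabulary and prompt and omits the special-token check (illustrative only, not part of this commit):

// toy_rollback.cpp - illustrative only; simulates dynamic rollback on string "tokens"
#include <cstdio>
#include <string>
#include <vector>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) != std::string::npos;
}

// Decide whether some toy-vocab entry could still cover the prefix.
static bool prefix_exists(const std::vector<std::string> & vocab, const std::string & prefix) {
    for (const std::string & token : vocab) {
        if (startswith(token, prefix)) return true;
    }
    return false;
}

int main() {
    const std::vector<std::string> vocab = { "Hello", ",", " wonderful", " world", " worldwide", "wide" };
    std::vector<std::string> prompt = { "Hello", ",", " wor" };  // prompt ends mid-token
    std::string prefix;
    int n_removed = 0;
    // Dynamic rollback: keep removing trailing tokens while some vocab entry still covers the prefix.
    while (n_removed < (int) prompt.size()) {
        const std::string candidate = prompt[prompt.size() - n_removed - 1] + prefix;
        if (!prefix_exists(vocab, candidate)) break;
        prefix = candidate;
        n_removed += 1;
    }
    prompt.resize(prompt.size() - n_removed);
    printf("removed %d token(s), healing prefix = '%s'\n", n_removed, prefix.c_str());
    return 0;
}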
//
// Sampling
//
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@ -33,6 +123,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
result->token_healing_prefix.clear();
result->prev.resize(params.n_prev);
llama_sampling_set_rng_seed(result, params.seed);
@ -62,6 +154,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}
ctx->token_healing_prefix.clear();
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
}
@ -119,7 +213,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
}
std::string llama_sampling_order_print(const llama_sampling_params & params) {
- std::string result = "CFG -> Penalties ";
+ std::string result = "(Token healing) -> CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@ -297,12 +391,33 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur.clear();
// Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_type = params.token_healing_type;
const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing_enabled && !th_prefix.empty()) {
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
for (const llama_token token_id : th_candidates) {
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
}
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
for (const llama_token token_id: th_candidates) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
// TODO should we skip penalties and grammar while token healing?
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
@ -361,4 +476,19 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
if (new_token_piece.size() < th_prefix.size()) {
// Shift prefix constraint (for multi step token healing)
th_prefix = th_prefix.substr(new_token_piece.size());
} else {
// Prefix has been generated => no more constrained generation
th_prefix.clear();
LOG("token_healing: done\n");
}
}
}
}
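The prefix-shift branch above can be traced with a small standalone sketch; the healed prefix and the accepted token pieces below are hypothetical, not taken from a real run:

// toy_prefix_shift.cpp - illustrative only; shows how the remaining healing prefix shrinks
#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::string th_prefix = " worldwide";                              // hypothetical healed prefix
    const std::vector<std::string> accepted = { " world", "wider" };   // hypothetical accepted token pieces
    for (const std::string & piece : accepted) {
        if (th_prefix.empty()) break;
        if (piece.size() < th_prefix.size()) {
            th_prefix = th_prefix.substr(piece.size());   // shift the constraint (multi-step healing)
        } else {
            th_prefix.clear();                            // prefix fully generated, constraint lifted
        }
        printf("accepted '%s' -> remaining prefix '%s'\n", piece.c_str(), th_prefix.c_str());
    }
    return 0;
}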


@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};
enum class llama_token_healing_type : uint8_t {
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
bool token_healing_enabled = false;
int token_healing_n_rollback = 1; // number of tokens to roll back
} llama_sampling_params;
// general sampler context
@ -78,6 +89,8 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;
std::string token_healing_prefix;
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
@ -152,3 +165,13 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);
//
// Token healing
//
std::string llama_token_healing_prepare(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int n_rollback = 1);


@ -264,6 +264,12 @@ int main(int argc, char ** argv) {
LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
std::string token_healing_prefix;
if (sparams.token_healing_enabled) {
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
sparams.token_healing_n_rollback);
}
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model));
@ -520,6 +526,7 @@ int main(int argc, char ** argv) {
}
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
ctx_sampling->token_healing_prefix = token_healing_prefix;
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict


@ -8,13 +8,6 @@
#define TH_VERBOSE // print token healing candidates
- enum class token_healing_type : uint8_t {
- ROLLBACK_LAST, // roll back last token with a single constrained decoding step
- ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
- DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
- DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
- };
struct token_healing_context {
std::string prefix; // remaining prefix to generate (the input prompt's suffix)
@ -44,8 +37,8 @@ static std::vector<llama_token> token_healing_find_prefix(
std::vector<llama_token> candidates;
const auto & vocab = th_ctx->vocab;
for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
- if (startswith(vocab[token_id], prefix)
- || (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
+ if (startswith(vocab[token_id], prefix) ||
+ (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
candidates.push_back((llama_token)token_id);
}
}
@ -71,14 +64,14 @@ static void token_healing_free(token_healing_context * th_ctx) {
static int token_healing_heal(
const llama_context * ctx,
std::vector<llama_token> & tokens_list,
- const token_healing_type th_type,
+ const llama_token_healing_type th_type,
token_healing_context * th_ctx,
int n_rollback = 1) {
if (tokens_list.empty()) {
return 0;
}
const llama_model * model = llama_get_model(ctx);
- const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
+ const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
const int n_ctx = tokens_list.size();
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
int n_removed = 0;
@ -104,7 +97,7 @@ static int token_healing_heal(
return 0;
}
// If constrained decoding would give back the original prompt, there is no need to modify the context
- const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
+ const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI;
const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
if (n_removed == 1 && candidates.size() == 1) {
@ -119,9 +112,7 @@ static int token_healing_heal(
}
}
#endif
- for (int i = 0; i < n_removed; ++i) {
- tokens_list.pop_back();
- }
+ tokens_list.resize(n_ctx - n_removed);
if (tokens_list.empty()) {
// If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
tokens_list.emplace_back(llama_token_bos(model));
@ -146,16 +137,16 @@ int main(int argc, char ** argv) {
}
bool token_healing_enabled = true;
- auto th_type = token_healing_type::DYNAMIC_MULTI;
+ auto th_type = llama_token_healing_type::DYNAMIC_MULTI;
int th_n_rollback = 1;
if (argc >= 4) {
std::string value(argv[3]);
/**/ if (value == "0" ) { token_healing_enabled = false; }
- else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
- else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
- else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
+ else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+ else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+ else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
- th_type = token_healing_type::ROLLBACK_MULTI;
+ th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
token_healing_enabled = false;
@ -281,7 +272,7 @@ int main(int argc, char ** argv) {
// Constrain tokens based on the remaining token healing prefix
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
std::vector<llama_token> th_candidates;
- if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
+ if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) {
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
} else {
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);