main : first attempt at token healing in main
This commit is contained in:
parent
88ef908c90
commit
951b6593b2
5 changed files with 200 additions and 24 deletions
|
@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
||||
return true;
|
||||
}
|
||||
if (arg == "-th" || arg == "--token-healing") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
return true;
|
||||
}
|
||||
sparams.token_healing_enabled = true;
|
||||
auto & th_type = sparams.token_healing_type;
|
||||
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value[0] == 'r' ) {
|
||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_n_rollback = std::stoi(value.substr(1));
|
||||
if (th_n_rollback <= 0) {
|
||||
sparams.token_healing_enabled = false;
|
||||
}
|
||||
} else { invalid_param = true; }
|
||||
return true;
|
||||
}
|
||||
if (arg == "--override-kv") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" -j SCHEMA, --json-schema SCHEMA\n");
|
||||
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
|
||||
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
|
||||
printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
|
||||
printf(" Token healing type. (default: 0, disabled)\n");
|
||||
printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
|
||||
printf(" --cfg-negative-prompt PROMPT\n");
|
||||
printf(" negative prompt to use for guidance. (default: empty)\n");
|
||||
printf(" --cfg-negative-prompt-file FNAME\n");
|
||||
|
|
|
@ -2,6 +2,96 @@
|
|||
#include "sampling.h"
|
||||
#include <random>
|
||||
|
||||
//
|
||||
// Token healing (internal)
|
||||
//
|
||||
|
||||
static bool startswith(const std::string & str, const std::string & prefix) {
|
||||
return str.rfind(prefix, 0) != std::string::npos;
|
||||
}
|
||||
|
||||
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::vector<llama_token> token_healing_find_prefix(
|
||||
const llama_context * ctx_main,
|
||||
const std::string & prefix,
|
||||
const bool include_partial_prefix) {
|
||||
// Example: prefix=" world" -> " world", " worldwide", ...
|
||||
// If `include_partial_prefix`, include also: " w", " wo", ...
|
||||
std::vector<llama_token> candidates;
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
if (startswith(token, prefix) ||
|
||||
(include_partial_prefix && startswith(prefix, token))) {
|
||||
candidates.push_back(token_id);
|
||||
}
|
||||
}
|
||||
return candidates;
|
||||
}
|
||||
|
||||
//
|
||||
// Token healing (external)
|
||||
//
|
||||
|
||||
std::string llama_token_healing_prepare(
|
||||
const llama_context * ctx_main,
|
||||
llama_token_healing_type th_type,
|
||||
std::vector<llama_token> & tokens,
|
||||
int n_rollback) {
|
||||
if (tokens.empty()) {
|
||||
return "";
|
||||
}
|
||||
const llama_model * model = llama_get_model(ctx_main);
|
||||
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const int n_ctx = tokens.size();
|
||||
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
||||
int n_removed = 0;
|
||||
std::string prefix;
|
||||
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
||||
// and stop early if a special token is encountered
|
||||
while (n_removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
|
||||
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
|
||||
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
|
||||
break;
|
||||
}
|
||||
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
|
||||
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
||||
break;
|
||||
}
|
||||
n_removed += 1;
|
||||
prefix = new_prefix;
|
||||
}
|
||||
|
||||
if (n_removed == 0) { // E.g. if the last token is a special token
|
||||
return "";
|
||||
}
|
||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
||||
if (n_removed == 1 && candidates.size() == 1) {
|
||||
LOG("token_healing: nothing to heal\n");
|
||||
return "";
|
||||
}
|
||||
tokens.resize(n_ctx - n_removed);
|
||||
return prefix;
|
||||
}
|
||||
|
||||
//
|
||||
// Sampling
|
||||
//
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
|
||||
|
@ -33,6 +123,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
|||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||
}
|
||||
|
||||
result->token_healing_prefix.clear();
|
||||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
llama_sampling_set_rng_seed(result, params.seed);
|
||||
|
@ -62,6 +154,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
|||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||
}
|
||||
|
||||
ctx->token_healing_prefix.clear();
|
||||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
}
|
||||
|
@ -119,7 +213,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|||
}
|
||||
|
||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||
std::string result = "CFG -> Penalties ";
|
||||
std::string result = "(Token healing) -> CFG -> Penalties ";
|
||||
if (params.mirostat == 0) {
|
||||
for (auto sampler_type : params.samplers_sequence) {
|
||||
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
|
||||
|
@ -297,12 +391,33 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
|
||||
cur.clear();
|
||||
|
||||
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||
const auto & th_type = params.token_healing_type;
|
||||
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (params.token_healing_enabled && !th_prefix.empty()) {
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
|
||||
|
||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||
for (const llama_token token_id : th_candidates) {
|
||||
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
|
||||
}
|
||||
|
||||
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||
for (const llama_token token_id: th_candidates) {
|
||||
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||
|
||||
// TODO should we skip penalties and grammar while token healing?
|
||||
|
||||
// apply penalties
|
||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||
|
@ -361,4 +476,19 @@ void llama_sampling_accept(
|
|||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
||||
}
|
||||
|
||||
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
|
||||
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (!th_prefix.empty()) {
|
||||
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||
if (new_token_piece.size() < th_prefix.size()) {
|
||||
// Shift prefix constraint (for multi step token healing)
|
||||
th_prefix = th_prefix.substr(new_token_piece.size());
|
||||
} else {
|
||||
// Prefix has been generated => no more constrained generation
|
||||
th_prefix.clear();
|
||||
LOG("token_healing: done\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
|
|||
TEMPERATURE = 't'
|
||||
};
|
||||
|
||||
enum class llama_token_healing_type : uint8_t {
|
||||
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
||||
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
||||
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
||||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
|
@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
|
|||
|
||||
std::vector<llama_token> penalty_prompt_tokens;
|
||||
bool use_penalty_prompt_tokens = false;
|
||||
|
||||
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
||||
bool token_healing_enabled = false;
|
||||
int token_healing_n_rollback = 1; // number of tokens to roll back
|
||||
} llama_sampling_params;
|
||||
|
||||
// general sampler context
|
||||
|
@ -78,6 +89,8 @@ struct llama_sampling_context {
|
|||
// internal
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
|
||||
std::string token_healing_prefix;
|
||||
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
|
@ -152,3 +165,13 @@ void llama_sampling_accept(
|
|||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar);
|
||||
|
||||
//
|
||||
// Token healing
|
||||
//
|
||||
|
||||
std::string llama_token_healing_prepare(
|
||||
const llama_context * ctx_main,
|
||||
llama_token_healing_type th_type,
|
||||
std::vector<llama_token> & tokens,
|
||||
int n_rollback = 1);
|
||||
|
|
|
@ -264,6 +264,12 @@ int main(int argc, char ** argv) {
|
|||
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
|
||||
std::string token_healing_prefix;
|
||||
if (sparams.token_healing_enabled) {
|
||||
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||
sparams.token_healing_n_rollback);
|
||||
}
|
||||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
|
@ -520,6 +526,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
ctx_sampling->token_healing_prefix = token_healing_prefix;
|
||||
|
||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||
// predict
|
||||
|
|
|
@ -8,13 +8,6 @@
|
|||
|
||||
#define TH_VERBOSE // print token healing candidates
|
||||
|
||||
enum class token_healing_type : uint8_t {
|
||||
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
||||
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
||||
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
||||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||
};
|
||||
|
||||
struct token_healing_context {
|
||||
std::string prefix; // remaining prefix to generate (the input prompt's suffix)
|
||||
|
||||
|
@ -44,8 +37,8 @@ static std::vector<llama_token> token_healing_find_prefix(
|
|||
std::vector<llama_token> candidates;
|
||||
const auto & vocab = th_ctx->vocab;
|
||||
for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
|
||||
if (startswith(vocab[token_id], prefix)
|
||||
|| (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
|
||||
if (startswith(vocab[token_id], prefix) ||
|
||||
(include_partial_prefix && startswith(prefix, vocab[token_id]))) {
|
||||
candidates.push_back((llama_token)token_id);
|
||||
}
|
||||
}
|
||||
|
@ -71,14 +64,14 @@ static void token_healing_free(token_healing_context * th_ctx) {
|
|||
static int token_healing_heal(
|
||||
const llama_context * ctx,
|
||||
std::vector<llama_token> & tokens_list,
|
||||
const token_healing_type th_type,
|
||||
const llama_token_healing_type th_type,
|
||||
token_healing_context * th_ctx,
|
||||
int n_rollback = 1) {
|
||||
if (tokens_list.empty()) {
|
||||
return 0;
|
||||
}
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
|
||||
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const int n_ctx = tokens_list.size();
|
||||
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
||||
int n_removed = 0;
|
||||
|
@ -104,7 +97,7 @@ static int token_healing_heal(
|
|||
return 0;
|
||||
}
|
||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||
const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
|
||||
const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI;
|
||||
const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
|
||||
fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
||||
if (n_removed == 1 && candidates.size() == 1) {
|
||||
|
@ -119,9 +112,7 @@ static int token_healing_heal(
|
|||
}
|
||||
}
|
||||
#endif
|
||||
for (int i = 0; i < n_removed; ++i) {
|
||||
tokens_list.pop_back();
|
||||
}
|
||||
tokens_list.resize(n_ctx - n_removed);
|
||||
if (tokens_list.empty()) {
|
||||
// If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
|
||||
tokens_list.emplace_back(llama_token_bos(model));
|
||||
|
@ -146,16 +137,16 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
bool token_healing_enabled = true;
|
||||
auto th_type = token_healing_type::DYNAMIC_MULTI;
|
||||
auto th_type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||
int th_n_rollback = 1;
|
||||
if (argc >= 4) {
|
||||
std::string value(argv[3]);
|
||||
/**/ if (value == "0" ) { token_healing_enabled = false; }
|
||||
else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
||||
else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value[0] == 'r' ) {
|
||||
th_type = token_healing_type::ROLLBACK_MULTI;
|
||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_n_rollback = std::stoi(value.substr(1));
|
||||
if (th_n_rollback <= 0) {
|
||||
token_healing_enabled = false;
|
||||
|
@ -281,7 +272,7 @@ int main(int argc, char ** argv) {
|
|||
// Constrain tokens based on the remaining token healing prefix
|
||||
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||
std::vector<llama_token> th_candidates;
|
||||
if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
|
||||
if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) {
|
||||
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
|
||||
} else {
|
||||
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue