main : first attempt at token healing in main
This commit is contained in:
parent
88ef908c90
commit
951b6593b2
5 changed files with 200 additions and 24 deletions
|
@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-th" || arg == "--token-healing") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
sparams.token_healing_enabled = true;
|
||||||
|
auto & th_type = sparams.token_healing_type;
|
||||||
|
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
||||||
|
std::string value(argv[i]);
|
||||||
|
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
||||||
|
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
||||||
|
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||||
|
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||||
|
else if (value[0] == 'r' ) {
|
||||||
|
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||||
|
th_n_rollback = std::stoi(value.substr(1));
|
||||||
|
if (th_n_rollback <= 0) {
|
||||||
|
sparams.token_healing_enabled = false;
|
||||||
|
}
|
||||||
|
} else { invalid_param = true; }
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "--override-kv") {
|
if (arg == "--override-kv") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
printf(" -j SCHEMA, --json-schema SCHEMA\n");
|
printf(" -j SCHEMA, --json-schema SCHEMA\n");
|
||||||
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
|
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
|
||||||
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
|
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
|
||||||
|
printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
|
||||||
|
printf(" Token healing type. (default: 0, disabled)\n");
|
||||||
|
printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
|
||||||
printf(" --cfg-negative-prompt PROMPT\n");
|
printf(" --cfg-negative-prompt PROMPT\n");
|
||||||
printf(" negative prompt to use for guidance. (default: empty)\n");
|
printf(" negative prompt to use for guidance. (default: empty)\n");
|
||||||
printf(" --cfg-negative-prompt-file FNAME\n");
|
printf(" --cfg-negative-prompt-file FNAME\n");
|
||||||
|
|
|
@ -2,6 +2,96 @@
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing (internal)
|
||||||
|
//
|
||||||
|
|
||||||
|
static bool startswith(const std::string & str, const std::string & prefix) {
|
||||||
|
return str.rfind(prefix, 0) != std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<llama_token> token_healing_find_prefix(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
const std::string & prefix,
|
||||||
|
const bool include_partial_prefix) {
|
||||||
|
// Example: prefix=" world" -> " world", " worldwide", ...
|
||||||
|
// If `include_partial_prefix`, include also: " w", " wo", ...
|
||||||
|
std::vector<llama_token> candidates;
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
if (startswith(token, prefix) ||
|
||||||
|
(include_partial_prefix && startswith(prefix, token))) {
|
||||||
|
candidates.push_back(token_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return candidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing (external)
|
||||||
|
//
|
||||||
|
|
||||||
|
std::string llama_token_healing_prepare(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
llama_token_healing_type th_type,
|
||||||
|
std::vector<llama_token> & tokens,
|
||||||
|
int n_rollback) {
|
||||||
|
if (tokens.empty()) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
const llama_model * model = llama_get_model(ctx_main);
|
||||||
|
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
const int n_ctx = tokens.size();
|
||||||
|
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
||||||
|
int n_removed = 0;
|
||||||
|
std::string prefix;
|
||||||
|
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
||||||
|
// and stop early if a special token is encountered
|
||||||
|
while (n_removed < max_to_remove) {
|
||||||
|
const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
|
||||||
|
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
|
||||||
|
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
|
||||||
|
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
n_removed += 1;
|
||||||
|
prefix = new_prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_removed == 0) { // E.g. if the last token is a special token
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
||||||
|
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
||||||
|
if (n_removed == 1 && candidates.size() == 1) {
|
||||||
|
LOG("token_healing: nothing to heal\n");
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
tokens.resize(n_ctx - n_removed);
|
||||||
|
return prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Sampling
|
||||||
|
//
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
|
@ -33,6 +123,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result->token_healing_prefix.clear();
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
llama_sampling_set_rng_seed(result, params.seed);
|
llama_sampling_set_rng_seed(result, params.seed);
|
||||||
|
@ -62,6 +154,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx->token_healing_prefix.clear();
|
||||||
|
|
||||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||||
ctx->cur.clear();
|
ctx->cur.clear();
|
||||||
}
|
}
|
||||||
|
@ -119,7 +213,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||||
std::string result = "CFG -> Penalties ";
|
std::string result = "(Token healing) -> CFG -> Penalties ";
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
for (auto sampler_type : params.samplers_sequence) {
|
for (auto sampler_type : params.samplers_sequence) {
|
||||||
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
|
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
|
||||||
|
@ -297,12 +391,33 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
cur.clear();
|
cur.clear();
|
||||||
|
|
||||||
|
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||||
|
const auto & th_type = params.token_healing_type;
|
||||||
|
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
|
if (params.token_healing_enabled && !th_prefix.empty()) {
|
||||||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
|
||||||
|
|
||||||
|
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||||
|
for (const llama_token token_id : th_candidates) {
|
||||||
|
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||||
|
for (const llama_token token_id: th_candidates) {
|
||||||
|
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
|
// TODO should we skip penalties and grammar while token healing?
|
||||||
|
|
||||||
// apply penalties
|
// apply penalties
|
||||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
|
@ -361,4 +476,19 @@ void llama_sampling_accept(
|
||||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||||
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
|
||||||
|
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
|
if (!th_prefix.empty()) {
|
||||||
|
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||||
|
if (new_token_piece.size() < th_prefix.size()) {
|
||||||
|
// Shift prefix constraint (for multi step token healing)
|
||||||
|
th_prefix = th_prefix.substr(new_token_piece.size());
|
||||||
|
} else {
|
||||||
|
// Prefix has been generated => no more constrained generation
|
||||||
|
th_prefix.clear();
|
||||||
|
LOG("token_healing: done\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
|
||||||
TEMPERATURE = 't'
|
TEMPERATURE = 't'
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class llama_token_healing_type : uint8_t {
|
||||||
|
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
||||||
|
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
||||||
|
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
||||||
|
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||||
|
};
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
|
||||||
|
|
||||||
std::vector<llama_token> penalty_prompt_tokens;
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
bool use_penalty_prompt_tokens = false;
|
bool use_penalty_prompt_tokens = false;
|
||||||
|
|
||||||
|
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
||||||
|
bool token_healing_enabled = false;
|
||||||
|
int token_healing_n_rollback = 1; // number of tokens to roll back
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
@ -78,6 +89,8 @@ struct llama_sampling_context {
|
||||||
// internal
|
// internal
|
||||||
grammar_parser::parse_state parsed_grammar;
|
grammar_parser::parse_state parsed_grammar;
|
||||||
|
|
||||||
|
std::string token_healing_prefix;
|
||||||
|
|
||||||
// TODO: replace with ring-buffer
|
// TODO: replace with ring-buffer
|
||||||
std::vector<llama_token> prev;
|
std::vector<llama_token> prev;
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
@ -152,3 +165,13 @@ void llama_sampling_accept(
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar);
|
bool apply_grammar);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing
|
||||||
|
//
|
||||||
|
|
||||||
|
std::string llama_token_healing_prepare(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
llama_token_healing_type th_type,
|
||||||
|
std::vector<llama_token> & tokens,
|
||||||
|
int n_rollback = 1);
|
||||||
|
|
|
@ -264,6 +264,12 @@ int main(int argc, char ** argv) {
|
||||||
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
||||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||||
|
|
||||||
|
std::string token_healing_prefix;
|
||||||
|
if (sparams.token_healing_enabled) {
|
||||||
|
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||||
|
sparams.token_healing_n_rollback);
|
||||||
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
if (embd_inp.empty()) {
|
if (embd_inp.empty()) {
|
||||||
embd_inp.push_back(llama_token_bos(model));
|
embd_inp.push_back(llama_token_bos(model));
|
||||||
|
@ -520,6 +526,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
|
ctx_sampling->token_healing_prefix = token_healing_prefix;
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
|
|
|
@ -8,13 +8,6 @@
|
||||||
|
|
||||||
#define TH_VERBOSE // print token healing candidates
|
#define TH_VERBOSE // print token healing candidates
|
||||||
|
|
||||||
enum class token_healing_type : uint8_t {
|
|
||||||
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
|
||||||
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
|
||||||
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
|
||||||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
|
||||||
};
|
|
||||||
|
|
||||||
struct token_healing_context {
|
struct token_healing_context {
|
||||||
std::string prefix; // remaining prefix to generate (the input prompt's suffix)
|
std::string prefix; // remaining prefix to generate (the input prompt's suffix)
|
||||||
|
|
||||||
|
@ -44,8 +37,8 @@ static std::vector<llama_token> token_healing_find_prefix(
|
||||||
std::vector<llama_token> candidates;
|
std::vector<llama_token> candidates;
|
||||||
const auto & vocab = th_ctx->vocab;
|
const auto & vocab = th_ctx->vocab;
|
||||||
for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
|
for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
|
||||||
if (startswith(vocab[token_id], prefix)
|
if (startswith(vocab[token_id], prefix) ||
|
||||||
|| (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
|
(include_partial_prefix && startswith(prefix, vocab[token_id]))) {
|
||||||
candidates.push_back((llama_token)token_id);
|
candidates.push_back((llama_token)token_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -71,14 +64,14 @@ static void token_healing_free(token_healing_context * th_ctx) {
|
||||||
static int token_healing_heal(
|
static int token_healing_heal(
|
||||||
const llama_context * ctx,
|
const llama_context * ctx,
|
||||||
std::vector<llama_token> & tokens_list,
|
std::vector<llama_token> & tokens_list,
|
||||||
const token_healing_type th_type,
|
const llama_token_healing_type th_type,
|
||||||
token_healing_context * th_ctx,
|
token_healing_context * th_ctx,
|
||||||
int n_rollback = 1) {
|
int n_rollback = 1) {
|
||||||
if (tokens_list.empty()) {
|
if (tokens_list.empty()) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
const bool is_dynamic = th_type == token_healing_type::DYNAMIC_ONCE || th_type == token_healing_type::DYNAMIC_MULTI;
|
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
const int n_ctx = tokens_list.size();
|
const int n_ctx = tokens_list.size();
|
||||||
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
||||||
int n_removed = 0;
|
int n_removed = 0;
|
||||||
|
@ -104,7 +97,7 @@ static int token_healing_heal(
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||||
const bool is_multi_decoding = th_type == token_healing_type::DYNAMIC_MULTI || th_type == token_healing_type::ROLLBACK_MULTI;
|
const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI;
|
||||||
const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
|
const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
|
||||||
fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
||||||
if (n_removed == 1 && candidates.size() == 1) {
|
if (n_removed == 1 && candidates.size() == 1) {
|
||||||
|
@ -119,9 +112,7 @@ static int token_healing_heal(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
for (int i = 0; i < n_removed; ++i) {
|
tokens_list.resize(n_ctx - n_removed);
|
||||||
tokens_list.pop_back();
|
|
||||||
}
|
|
||||||
if (tokens_list.empty()) {
|
if (tokens_list.empty()) {
|
||||||
// If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
|
// If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
|
||||||
tokens_list.emplace_back(llama_token_bos(model));
|
tokens_list.emplace_back(llama_token_bos(model));
|
||||||
|
@ -146,16 +137,16 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool token_healing_enabled = true;
|
bool token_healing_enabled = true;
|
||||||
auto th_type = token_healing_type::DYNAMIC_MULTI;
|
auto th_type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
int th_n_rollback = 1;
|
int th_n_rollback = 1;
|
||||||
if (argc >= 4) {
|
if (argc >= 4) {
|
||||||
std::string value(argv[3]);
|
std::string value(argv[3]);
|
||||||
/**/ if (value == "0" ) { token_healing_enabled = false; }
|
/**/ if (value == "0" ) { token_healing_enabled = false; }
|
||||||
else if (value == "1" ) { th_type = token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
||||||
else if (value == "d1") { th_type = token_healing_type::DYNAMIC_ONCE; }
|
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||||
else if (value == "d" ) { th_type = token_healing_type::DYNAMIC_MULTI; }
|
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||||
else if (value[0] == 'r' ) {
|
else if (value[0] == 'r' ) {
|
||||||
th_type = token_healing_type::ROLLBACK_MULTI;
|
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||||
th_n_rollback = std::stoi(value.substr(1));
|
th_n_rollback = std::stoi(value.substr(1));
|
||||||
if (th_n_rollback <= 0) {
|
if (th_n_rollback <= 0) {
|
||||||
token_healing_enabled = false;
|
token_healing_enabled = false;
|
||||||
|
@ -281,7 +272,7 @@ int main(int argc, char ** argv) {
|
||||||
// Constrain tokens based on the remaining token healing prefix
|
// Constrain tokens based on the remaining token healing prefix
|
||||||
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||||
std::vector<llama_token> th_candidates;
|
std::vector<llama_token> th_candidates;
|
||||||
if (th_type == token_healing_type::ROLLBACK_LAST || th_type == token_healing_type::DYNAMIC_ONCE) {
|
if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) {
|
||||||
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
|
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
|
||||||
} else {
|
} else {
|
||||||
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
|
th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue