Merge b27f87d6da
into 8f1d81a0b6
This commit is contained in:
commit
71086fefba
7 changed files with 384 additions and 14 deletions
|
@ -1435,6 +1435,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-th" || arg == "--token-healing") {
|
||||||
|
CHECK_ARG
|
||||||
|
std::string value(argv[i]);
|
||||||
|
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "--override-kv") {
|
if (arg == "--override-kv") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
|
if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
|
||||||
|
@ -1872,6 +1878,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||||
"if suffix/prefix are specified, template will be disabled\n"
|
"if suffix/prefix are specified, template will be disabled\n"
|
||||||
"only commonly used templates are accepted:\n"
|
"only commonly used templates are accepted:\n"
|
||||||
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
||||||
|
|
||||||
|
options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
|
||||||
|
"Token healing type. (default: 0, disabled)\n"
|
||||||
|
"1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
|
||||||
options.push_back({ "grammar" });
|
options.push_back({ "grammar" });
|
||||||
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
|
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
|
||||||
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
|
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
|
||||||
|
|
|
@ -2,6 +2,181 @@
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing (internal)
|
||||||
|
//
|
||||||
|
|
||||||
|
static bool startswith(const std::string & str, const std::string & prefix) {
|
||||||
|
return str.rfind(prefix, 0) != std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
if (startswith(token, prefix)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<llama_token> token_healing_get_candidates(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
const std::string & prefix,
|
||||||
|
const bool include_partial_prefix) {
|
||||||
|
// Example: prefix=" world" -> " world", " worldwide", ...
|
||||||
|
// If `include_partial_prefix`, include also: " w", " wo", ...
|
||||||
|
std::vector<llama_token> candidates;
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
if (startswith(token, prefix) ||
|
||||||
|
(include_partial_prefix && startswith(prefix, token))) {
|
||||||
|
candidates.push_back(token_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return candidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t get_max_token_length(const llama_context * ctx_main) {
|
||||||
|
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
size_t len = 0;
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||||
|
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||||
|
len = std::max(len, token.size());
|
||||||
|
}
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
|
||||||
|
static llama_token_healing_output llama_token_healing_get_prefix(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
const llama_token_healing_type th_type,
|
||||||
|
const std::vector<llama_token> & tokens,
|
||||||
|
int max_to_remove) {
|
||||||
|
if (tokens.size() <= 1) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_ctx = tokens.size();
|
||||||
|
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||||
|
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
||||||
|
|
||||||
|
int removed = 0;
|
||||||
|
std::string prefix;
|
||||||
|
|
||||||
|
const llama_model * model = llama_get_model(ctx_main);
|
||||||
|
auto is_special_token = [&](const llama_token token_id) {
|
||||||
|
return llama_token_is_control(model, token_id)
|
||||||
|
|| llama_token_bos (model) == token_id
|
||||||
|
|| llama_token_eos (model) == token_id
|
||||||
|
|| llama_token_cls (model) == token_id
|
||||||
|
|| llama_token_sep (model) == token_id
|
||||||
|
|| llama_token_pad (model) == token_id
|
||||||
|
|| llama_token_prefix (model) == token_id
|
||||||
|
|| llama_token_middle (model) == token_id
|
||||||
|
|| llama_token_suffix (model) == token_id
|
||||||
|
|| llama_token_eot (model) == token_id;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
|
||||||
|
// The number of bytes to roll back cannot exceed the length of the longest token.
|
||||||
|
const size_t n_longest_token = get_max_token_length(ctx_main);
|
||||||
|
size_t len = 0;
|
||||||
|
while (removed < max_to_remove) {
|
||||||
|
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||||
|
if (is_special_token(next_token_id)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
|
||||||
|
if (len + next_token_size > n_longest_token) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
len += next_token_size;
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (removed > 0) {
|
||||||
|
prefix.clear();
|
||||||
|
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||||
|
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||||
|
}
|
||||||
|
if (token_healing_prefix_exists(ctx_main, prefix)) {
|
||||||
|
break; // Stop on longest valid prefix
|
||||||
|
}
|
||||||
|
removed -= 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Roll back tokens a fixed amount and stop early if a special token is encountered.
|
||||||
|
while (removed < max_to_remove) {
|
||||||
|
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||||
|
if (is_special_token(next_token_id)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||||
|
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {prefix, removed};
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing (external)
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_token_healing_output llama_token_healing_rollback(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
std::vector<llama_token> & tokens,
|
||||||
|
llama_token_healing_type th_type,
|
||||||
|
int max_to_remove) {
|
||||||
|
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
||||||
|
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
||||||
|
llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
|
||||||
|
|
||||||
|
// If constrained decoding would give back the original prompt, there is no need to modify the prompt.
|
||||||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
|
||||||
|
LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
|
||||||
|
if (out.n_tokens_removed == 1 && candidates.size() == 1) {
|
||||||
|
LOG("token_healing: nothing to heal\n");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, trim prompt tokens
|
||||||
|
tokens.resize(tokens.size() - out.n_tokens_removed);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
||||||
|
ctx_sampling->token_healing_prefix = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
|
||||||
|
th_params.enabled = true;
|
||||||
|
th_params.n_rollback = -1;
|
||||||
|
/**/ if (params == "0" ) { th_params.enabled = false; }
|
||||||
|
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||||
|
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||||
|
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||||
|
else if (params[0] == 'r' ) {
|
||||||
|
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||||
|
th_params.n_rollback = std::stoi(params.substr(1));
|
||||||
|
if (th_params.n_rollback <= 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Sampling
|
||||||
|
//
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
|
@ -72,6 +247,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
ctx->grammar = grammar;
|
ctx->grammar = grammar;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx->token_healing_prefix.clear();
|
||||||
|
|
||||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||||
ctx->cur.clear();
|
ctx->cur.clear();
|
||||||
ctx->n_valid = 0;
|
ctx->n_valid = 0;
|
||||||
|
@ -130,7 +307,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||||
std::string result = "CFG -> Penalties ";
|
std::string result = "(Token healing) -> CFG -> Penalties ";
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
for (auto sampler_type : params.samplers_sequence) {
|
for (auto sampler_type : params.samplers_sequence) {
|
||||||
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
|
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
|
||||||
|
@ -391,10 +568,28 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
cur.resize(n_vocab);
|
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||||
|
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
|
if (params.token_healing.enabled && !th_prefix.empty()) {
|
||||||
|
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
|
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||||
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
for (const llama_token token_id : th_candidates) {
|
||||||
|
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||||
|
cur.clear();
|
||||||
|
for (const llama_token token_id : th_candidates) {
|
||||||
|
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cur.resize(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
@ -457,4 +652,19 @@ void llama_sampling_accept(
|
||||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
|
||||||
|
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
|
if (!th_prefix.empty()) {
|
||||||
|
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||||
|
if (new_token_piece.size() < th_prefix.size()) {
|
||||||
|
// Shift prefix constraint (for multi step token healing)
|
||||||
|
th_prefix = th_prefix.substr(new_token_piece.size());
|
||||||
|
} else {
|
||||||
|
// Prefix has been generated => no more constrained generation
|
||||||
|
th_prefix.clear();
|
||||||
|
LOG("token_healing: done\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,19 @@ enum class llama_sampler_type : char {
|
||||||
TEMPERATURE = 't'
|
TEMPERATURE = 't'
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class llama_token_healing_type : uint8_t {
|
||||||
|
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
||||||
|
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
||||||
|
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
||||||
|
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_token_healing_params {
|
||||||
|
bool enabled = false;
|
||||||
|
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
int n_rollback = -1; // number of tokens to roll back
|
||||||
|
};
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
@ -62,6 +75,8 @@ typedef struct llama_sampling_params {
|
||||||
|
|
||||||
std::vector<llama_token> penalty_prompt_tokens;
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
bool use_penalty_prompt_tokens = false;
|
bool use_penalty_prompt_tokens = false;
|
||||||
|
|
||||||
|
llama_token_healing_params token_healing;
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
@ -78,6 +93,8 @@ struct llama_sampling_context {
|
||||||
// internal
|
// internal
|
||||||
grammar_parser::parse_state parsed_grammar;
|
grammar_parser::parse_state parsed_grammar;
|
||||||
|
|
||||||
|
std::string token_healing_prefix; // remaining prefix to constrain sampling
|
||||||
|
|
||||||
// TODO: replace with ring-buffer
|
// TODO: replace with ring-buffer
|
||||||
std::vector<llama_token> prev;
|
std::vector<llama_token> prev;
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
@ -158,3 +175,25 @@ void llama_sampling_accept(
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar);
|
bool apply_grammar);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token healing
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_token_healing_output {
|
||||||
|
std::string prefix;
|
||||||
|
int n_tokens_removed;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Roll back `tokens` for constrained generation according to the token healing strategy.
|
||||||
|
// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
|
||||||
|
llama_token_healing_output llama_token_healing_rollback(
|
||||||
|
const llama_context * ctx_main,
|
||||||
|
std::vector<llama_token> & tokens,
|
||||||
|
llama_token_healing_type th_type,
|
||||||
|
int max_to_remove = -1);
|
||||||
|
|
||||||
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
||||||
|
|
||||||
|
// Helper for parsing token healing params from a string.
|
||||||
|
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
|
||||||
|
|
|
@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
|
||||||
|
|
||||||
Example usage: `--logit-bias 29905-inf`
|
Example usage: `--logit-bias 29905-inf`
|
||||||
|
|
||||||
|
### Token healing
|
||||||
|
|
||||||
|
- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
|
||||||
|
|
||||||
|
Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
|
||||||
|
|
||||||
|
- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
|
||||||
|
- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
|
||||||
|
- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
|
||||||
|
- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
|
||||||
|
|
||||||
|
Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
|
||||||
|
|
||||||
### RNG Seed
|
### RNG Seed
|
||||||
|
|
||||||
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
|
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
|
||||||
|
|
|
@ -325,6 +325,16 @@ int main(int argc, char ** argv) {
|
||||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||||
|
sparams.token_healing.enabled = false;
|
||||||
|
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||||
|
}
|
||||||
|
llama_token_healing_output token_healing_out{};
|
||||||
|
if (!params.interactive_first && sparams.token_healing.enabled) {
|
||||||
|
token_healing_out = llama_token_healing_rollback(ctx, embd_inp,
|
||||||
|
sparams.token_healing.type, sparams.token_healing.n_rollback);
|
||||||
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
if (embd_inp.empty()) {
|
if (embd_inp.empty()) {
|
||||||
if (add_bos) {
|
if (add_bos) {
|
||||||
|
@ -349,7 +359,7 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
||||||
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
||||||
|
|
||||||
original_prompt_len = original_inp.size();
|
original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
|
||||||
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
||||||
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
||||||
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
||||||
|
@ -544,6 +554,7 @@ int main(int argc, char ** argv) {
|
||||||
int n_consumed = 0;
|
int n_consumed = 0;
|
||||||
int n_session_consumed = 0;
|
int n_session_consumed = 0;
|
||||||
int n_past_guidance = 0;
|
int n_past_guidance = 0;
|
||||||
|
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
|
||||||
|
|
||||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||||
|
@ -570,6 +581,7 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
|
||||||
|
|
||||||
if (llama_model_has_encoder(model)) {
|
if (llama_model_has_encoder(model)) {
|
||||||
int enc_input_size = embd_inp.size();
|
int enc_input_size = embd_inp.size();
|
||||||
|
@ -804,7 +816,15 @@ int main(int argc, char ** argv) {
|
||||||
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
|
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
|
||||||
|
|
||||||
// Console/Stream Output
|
// Console/Stream Output
|
||||||
fprintf(stdout, "%s", token_str.c_str());
|
// Suppress printing while generating token healing prefix
|
||||||
|
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
|
||||||
|
fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
|
||||||
|
n_bytes_to_skip = 0;
|
||||||
|
} else if (n_bytes_to_skip > 0) {
|
||||||
|
n_bytes_to_skip -= token_str.size();
|
||||||
|
} else {
|
||||||
|
fprintf(stdout, "%s", token_str.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
// Record Displayed Tokens To Log
|
// Record Displayed Tokens To Log
|
||||||
// Note: Generated tokens are created one by one hence this check
|
// Note: Generated tokens are created one by one hence this check
|
||||||
|
@ -896,6 +916,8 @@ int main(int argc, char ** argv) {
|
||||||
assistant_ss << llama_token_to_piece(ctx, id, false);
|
assistant_ss << llama_token_to_piece(ctx, id, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
token_healing_out = {};
|
||||||
|
|
||||||
if (n_past > 0 && is_interacting) {
|
if (n_past > 0 && is_interacting) {
|
||||||
LOG("waiting for user input\n");
|
LOG("waiting for user input\n");
|
||||||
|
|
||||||
|
@ -968,6 +990,16 @@ int main(int argc, char ** argv) {
|
||||||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||||
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
||||||
|
|
||||||
|
if (sparams.token_healing.enabled) {
|
||||||
|
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
||||||
|
const int n_new_tokens = embd_inp.size() - original_size;
|
||||||
|
const int max_to_remove = sparams.token_healing.n_rollback < 0
|
||||||
|
? n_new_tokens
|
||||||
|
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
|
||||||
|
token_healing_out = llama_token_healing_rollback(ctx, embd_inp, sparams.token_healing.type, max_to_remove);
|
||||||
|
n_bytes_to_skip = token_healing_out.prefix.size();
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||||
const llama_token token = embd_inp[i];
|
const llama_token token = embd_inp[i];
|
||||||
output_tokens.push_back(token);
|
output_tokens.push_back(token);
|
||||||
|
@ -977,7 +1009,7 @@ int main(int argc, char ** argv) {
|
||||||
// reset assistant message
|
// reset assistant message
|
||||||
assistant_ss.str("");
|
assistant_ss.str("");
|
||||||
|
|
||||||
n_remain -= line_inp.size();
|
n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
|
||||||
LOG("n_remain: %d\n", n_remain);
|
LOG("n_remain: %d\n", n_remain);
|
||||||
} else {
|
} else {
|
||||||
LOG("empty line, passing control back\n");
|
LOG("empty line, passing control back\n");
|
||||||
|
@ -989,6 +1021,10 @@ int main(int argc, char ** argv) {
|
||||||
if (n_past > 0) {
|
if (n_past > 0) {
|
||||||
if (is_interacting) {
|
if (is_interacting) {
|
||||||
llama_sampling_reset(ctx_sampling);
|
llama_sampling_reset(ctx_sampling);
|
||||||
|
if (token_healing_out.n_tokens_removed > 0) {
|
||||||
|
// Set new prefix after an interaction
|
||||||
|
llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -482,6 +482,8 @@ node index.js
|
||||||
|
|
||||||
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
|
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
|
||||||
|
|
||||||
|
`token_healing`: Set the token healing strategy. Default: `0`, which is disabled. Possible values: `1` to replace one token, `d1` to replace the longest suffix with a single token, `d` to replace the longest suffix, `rN` to roll back N tokens (e.g. `r3`). See [here](../main/README.md#token-healing) for more details.
|
||||||
|
|
||||||
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
|
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
|
||||||
|
|
||||||
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
|
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
|
||||||
|
|
|
@ -186,6 +186,7 @@ struct server_slot {
|
||||||
// stats
|
// stats
|
||||||
size_t n_sent_text = 0; // number of sent text character
|
size_t n_sent_text = 0; // number of sent text character
|
||||||
size_t n_sent_token_probs = 0;
|
size_t n_sent_token_probs = 0;
|
||||||
|
size_t n_th_prefix = 0; // size of remaining token healing prefix
|
||||||
|
|
||||||
int64_t t_start_process_prompt;
|
int64_t t_start_process_prompt;
|
||||||
int64_t t_start_generation;
|
int64_t t_start_generation;
|
||||||
|
@ -207,6 +208,7 @@ struct server_slot {
|
||||||
infill = false;
|
infill = false;
|
||||||
ga_i = 0;
|
ga_i = 0;
|
||||||
n_past_se = 0;
|
n_past_se = 0;
|
||||||
|
n_th_prefix = 0;
|
||||||
|
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
}
|
}
|
||||||
|
@ -1096,6 +1098,25 @@ struct server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const auto & token_healing_str = data.find("token_healing");
|
||||||
|
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
||||||
|
const auto value = token_healing_str->get<std::string>();
|
||||||
|
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
|
||||||
|
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
LOG_VERBOSE("token healing", {
|
||||||
|
{"id_slot", slot.id},
|
||||||
|
{"enabled", slot.sparams.token_healing.enabled},
|
||||||
|
{"type", slot.sparams.token_healing.type},
|
||||||
|
{"n_rollback", slot.sparams.token_healing.n_rollback}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
slot.sparams.token_healing = default_sparams.token_healing;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
if (slot.ctx_sampling != nullptr) {
|
if (slot.ctx_sampling != nullptr) {
|
||||||
llama_sampling_free(slot.ctx_sampling);
|
llama_sampling_free(slot.ctx_sampling);
|
||||||
|
@ -1182,14 +1203,26 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
|
||||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||||
slot.sampled = result.tok;
|
slot.sampled = result.tok;
|
||||||
|
|
||||||
// search stop word and delete it
|
|
||||||
slot.generated_text += token_str;
|
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
|
|
||||||
|
// Suppress generating the token healing prefix to not repeat the input prompt's suffix
|
||||||
|
bool is_token_healing = false;
|
||||||
|
if (slot.n_th_prefix > 0) {
|
||||||
|
if (slot.n_th_prefix < token_str.size()) {
|
||||||
|
slot.generated_text += token_str.substr(slot.n_th_prefix);
|
||||||
|
slot.n_th_prefix = 0;
|
||||||
|
is_token_healing = false; // to send partial token text when streaming
|
||||||
|
} else {
|
||||||
|
slot.n_th_prefix -= token_str.size();
|
||||||
|
is_token_healing = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
slot.generated_text += token_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||||
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
||||||
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||||
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||||
|
@ -1217,7 +1250,7 @@ struct server_context {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!incomplete) {
|
if (!incomplete && !is_token_healing) {
|
||||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||||
|
|
||||||
const std::string str_test = slot.generated_text.substr(pos);
|
const std::string str_test = slot.generated_text.substr(pos);
|
||||||
|
@ -1249,7 +1282,7 @@ struct server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (incomplete) {
|
if (incomplete || is_token_healing) {
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1354,7 +1387,8 @@ struct server_context {
|
||||||
{"n_probs", slot.sparams.n_probs},
|
{"n_probs", slot.sparams.n_probs},
|
||||||
{"min_keep", slot.sparams.min_keep},
|
{"min_keep", slot.sparams.min_keep},
|
||||||
{"grammar", slot.sparams.grammar},
|
{"grammar", slot.sparams.grammar},
|
||||||
{"samplers", samplers_sequence}
|
{"samplers", samplers_sequence},
|
||||||
|
{"token_healing_enabled", slot.sparams.token_healing.enabled}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2038,6 +2072,8 @@ struct server_context {
|
||||||
slot.t_start_process_prompt = ggml_time_us();
|
slot.t_start_process_prompt = ggml_time_us();
|
||||||
slot.t_start_generation = 0;
|
slot.t_start_generation = 0;
|
||||||
|
|
||||||
|
llama_token_healing_output token_healing_out{};
|
||||||
|
|
||||||
if (slot.infill) {
|
if (slot.infill) {
|
||||||
const bool add_bos = llama_add_bos_token(model);
|
const bool add_bos = llama_add_bos_token(model);
|
||||||
bool suff_rm_leading_spc = true;
|
bool suff_rm_leading_spc = true;
|
||||||
|
@ -2057,6 +2093,12 @@ struct server_context {
|
||||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
||||||
|
|
||||||
|
if (slot.sparams.token_healing.enabled) {
|
||||||
|
// For FIM roll back only the prefix part (i.e. cursor location)
|
||||||
|
token_healing_out = llama_token_healing_rollback(ctx, prefix_tokens,
|
||||||
|
slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
|
||||||
|
}
|
||||||
|
|
||||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
||||||
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
||||||
if (add_bos) {
|
if (add_bos) {
|
||||||
|
@ -2072,6 +2114,11 @@ struct server_context {
|
||||||
prompt_tokens = embd_inp;
|
prompt_tokens = embd_inp;
|
||||||
} else {
|
} else {
|
||||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||||
|
|
||||||
|
if (slot.sparams.token_healing.enabled) {
|
||||||
|
token_healing_out = llama_token_healing_rollback(ctx, prompt_tokens,
|
||||||
|
slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
|
@ -2086,6 +2133,16 @@ struct server_context {
|
||||||
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (slot.sparams.token_healing.enabled) {
|
||||||
|
slot.n_th_prefix = token_healing_out.prefix.size();
|
||||||
|
LOG_VERBOSE("token healing prompt", {
|
||||||
|
{"id_slot", slot.id},
|
||||||
|
{"id_task", slot.id_task},
|
||||||
|
{"removed_suffix", token_healing_out.prefix},
|
||||||
|
{"n_tokens_removed", token_healing_out.n_tokens_removed}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// empty prompt passed -> release the slot and send empty response
|
// empty prompt passed -> release the slot and send empty response
|
||||||
if (prompt_tokens.empty()) {
|
if (prompt_tokens.empty()) {
|
||||||
LOG_INFO("empty prompt - releasing slot", {
|
LOG_INFO("empty prompt - releasing slot", {
|
||||||
|
@ -2151,6 +2208,9 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
|
if (slot.sparams.token_healing.enabled) {
|
||||||
|
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
||||||
|
}
|
||||||
|
|
||||||
if (!slot.params.cache_prompt) {
|
if (!slot.params.cache_prompt) {
|
||||||
slot.n_past_se = 0;
|
slot.n_past_se = 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue