Merge b27f87d6da
into 8f1d81a0b6
This commit is contained in:
commit
71086fefba
7 changed files with 384 additions and 14 deletions
|
@ -1435,6 +1435,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
|
||||
return true;
|
||||
}
|
||||
if (arg == "-th" || arg == "--token-healing") {
|
||||
CHECK_ARG
|
||||
std::string value(argv[i]);
|
||||
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--override-kv") {
|
||||
CHECK_ARG
|
||||
if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
|
||||
|
@ -1872,6 +1878,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||
"if suffix/prefix are specified, template will be disabled\n"
|
||||
"only commonly used templates are accepted:\n"
|
||||
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
||||
|
||||
options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
|
||||
"Token healing type. (default: 0, disabled)\n"
|
||||
"1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
|
||||
options.push_back({ "grammar" });
|
||||
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
|
||||
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
|
||||
|
|
|
@ -2,6 +2,181 @@
|
|||
#include "sampling.h"
|
||||
#include <random>
|
||||
|
||||
//
|
||||
// Token healing (internal)
|
||||
//
|
||||
|
||||
static bool startswith(const std::string & str, const std::string & prefix) {
|
||||
return str.rfind(prefix, 0) != std::string::npos;
|
||||
}
|
||||
|
||||
static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
if (startswith(token, prefix)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::vector<llama_token> token_healing_get_candidates(
|
||||
const llama_context * ctx_main,
|
||||
const std::string & prefix,
|
||||
const bool include_partial_prefix) {
|
||||
// Example: prefix=" world" -> " world", " worldwide", ...
|
||||
// If `include_partial_prefix`, include also: " w", " wo", ...
|
||||
std::vector<llama_token> candidates;
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
if (startswith(token, prefix) ||
|
||||
(include_partial_prefix && startswith(prefix, token))) {
|
||||
candidates.push_back(token_id);
|
||||
}
|
||||
}
|
||||
return candidates;
|
||||
}
|
||||
|
||||
static size_t get_max_token_length(const llama_context * ctx_main) {
|
||||
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
size_t len = 0;
|
||||
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
|
||||
std::string token = llama_token_to_piece(ctx_main, token_id);
|
||||
len = std::max(len, token.size());
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
static llama_token_healing_output llama_token_healing_get_prefix(
|
||||
const llama_context * ctx_main,
|
||||
const llama_token_healing_type th_type,
|
||||
const std::vector<llama_token> & tokens,
|
||||
int max_to_remove) {
|
||||
if (tokens.size() <= 1) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const int n_ctx = tokens.size();
|
||||
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
|
||||
|
||||
int removed = 0;
|
||||
std::string prefix;
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_main);
|
||||
auto is_special_token = [&](const llama_token token_id) {
|
||||
return llama_token_is_control(model, token_id)
|
||||
|| llama_token_bos (model) == token_id
|
||||
|| llama_token_eos (model) == token_id
|
||||
|| llama_token_cls (model) == token_id
|
||||
|| llama_token_sep (model) == token_id
|
||||
|| llama_token_pad (model) == token_id
|
||||
|| llama_token_prefix (model) == token_id
|
||||
|| llama_token_middle (model) == token_id
|
||||
|| llama_token_suffix (model) == token_id
|
||||
|| llama_token_eot (model) == token_id;
|
||||
};
|
||||
|
||||
if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
|
||||
// The number of bytes to roll back cannot exceed the length of the longest token.
|
||||
const size_t n_longest_token = get_max_token_length(ctx_main);
|
||||
size_t len = 0;
|
||||
while (removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||
if (is_special_token(next_token_id)) {
|
||||
break;
|
||||
}
|
||||
const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
|
||||
if (len + next_token_size > n_longest_token) {
|
||||
break;
|
||||
}
|
||||
len += next_token_size;
|
||||
removed += 1;
|
||||
}
|
||||
|
||||
while (removed > 0) {
|
||||
prefix.clear();
|
||||
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||
}
|
||||
if (token_healing_prefix_exists(ctx_main, prefix)) {
|
||||
break; // Stop on longest valid prefix
|
||||
}
|
||||
removed -= 1;
|
||||
}
|
||||
} else {
|
||||
// Roll back tokens a fixed amount and stop early if a special token is encountered.
|
||||
while (removed < max_to_remove) {
|
||||
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||
if (is_special_token(next_token_id)) {
|
||||
break;
|
||||
}
|
||||
removed += 1;
|
||||
}
|
||||
for (int i = n_ctx - removed; i < n_ctx; i++) {
|
||||
prefix += llama_token_to_piece(ctx_main, tokens[i]);
|
||||
}
|
||||
}
|
||||
return {prefix, removed};
|
||||
}
|
||||
|
||||
//
|
||||
// Token healing (external)
|
||||
//
|
||||
|
||||
llama_token_healing_output llama_token_healing_rollback(
|
||||
const llama_context * ctx_main,
|
||||
std::vector<llama_token> & tokens,
|
||||
llama_token_healing_type th_type,
|
||||
int max_to_remove) {
|
||||
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
|
||||
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
|
||||
llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
|
||||
|
||||
// If constrained decoding would give back the original prompt, there is no need to modify the prompt.
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
|
||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
|
||||
if (out.n_tokens_removed == 1 && candidates.size() == 1) {
|
||||
LOG("token_healing: nothing to heal\n");
|
||||
return {};
|
||||
}
|
||||
|
||||
// Finally, trim prompt tokens
|
||||
tokens.resize(tokens.size() - out.n_tokens_removed);
|
||||
return out;
|
||||
}
|
||||
|
||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
||||
ctx_sampling->token_healing_prefix = prefix;
|
||||
}
|
||||
|
||||
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
|
||||
th_params.enabled = true;
|
||||
th_params.n_rollback = -1;
|
||||
/**/ if (params == "0" ) { th_params.enabled = false; }
|
||||
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (params[0] == 'r' ) {
|
||||
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_params.n_rollback = std::stoi(params.substr(1));
|
||||
if (th_params.n_rollback <= 0) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// Sampling
|
||||
//
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
|
||||
|
@ -72,6 +247,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
|||
ctx->grammar = grammar;
|
||||
}
|
||||
|
||||
ctx->token_healing_prefix.clear();
|
||||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
ctx->n_valid = 0;
|
||||
|
@ -130,7 +307,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
|
|||
}
|
||||
|
||||
std::string llama_sampling_order_print(const llama_sampling_params & params) {
|
||||
std::string result = "CFG -> Penalties ";
|
||||
std::string result = "(Token healing) -> CFG -> Penalties ";
|
||||
if (params.mirostat == 0) {
|
||||
for (auto sampler_type : params.samplers_sequence) {
|
||||
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
|
||||
|
@ -391,10 +568,28 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||
}
|
||||
|
||||
cur.resize(n_vocab);
|
||||
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (params.token_healing.enabled && !th_prefix.empty()) {
|
||||
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||
for (const llama_token token_id : th_candidates) {
|
||||
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
|
||||
}
|
||||
|
||||
// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
|
||||
cur.clear();
|
||||
for (const llama_token token_id : th_candidates) {
|
||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
} else {
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||
|
@ -457,4 +652,19 @@ void llama_sampling_accept(
|
|||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||
}
|
||||
|
||||
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
|
||||
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (!th_prefix.empty()) {
|
||||
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||
if (new_token_piece.size() < th_prefix.size()) {
|
||||
// Shift prefix constraint (for multi step token healing)
|
||||
th_prefix = th_prefix.substr(new_token_piece.size());
|
||||
} else {
|
||||
// Prefix has been generated => no more constrained generation
|
||||
th_prefix.clear();
|
||||
LOG("token_healing: done\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,19 @@ enum class llama_sampler_type : char {
|
|||
TEMPERATURE = 't'
|
||||
};
|
||||
|
||||
enum class llama_token_healing_type : uint8_t {
|
||||
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
|
||||
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
|
||||
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
|
||||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||
};
|
||||
|
||||
struct llama_token_healing_params {
|
||||
bool enabled = false;
|
||||
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||
int n_rollback = -1; // number of tokens to roll back
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
|
@ -62,6 +75,8 @@ typedef struct llama_sampling_params {
|
|||
|
||||
std::vector<llama_token> penalty_prompt_tokens;
|
||||
bool use_penalty_prompt_tokens = false;
|
||||
|
||||
llama_token_healing_params token_healing;
|
||||
} llama_sampling_params;
|
||||
|
||||
// general sampler context
|
||||
|
@ -78,6 +93,8 @@ struct llama_sampling_context {
|
|||
// internal
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
|
||||
std::string token_healing_prefix; // remaining prefix to constrain sampling
|
||||
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
|
@ -158,3 +175,25 @@ void llama_sampling_accept(
|
|||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar);
|
||||
|
||||
//
|
||||
// Token healing
|
||||
//
|
||||
|
||||
struct llama_token_healing_output {
|
||||
std::string prefix;
|
||||
int n_tokens_removed;
|
||||
};
|
||||
|
||||
// Roll back `tokens` for constrained generation according to the token healing strategy.
|
||||
// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
|
||||
llama_token_healing_output llama_token_healing_rollback(
|
||||
const llama_context * ctx_main,
|
||||
std::vector<llama_token> & tokens,
|
||||
llama_token_healing_type th_type,
|
||||
int max_to_remove = -1);
|
||||
|
||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
||||
|
||||
// Helper for parsing token healing params from a string.
|
||||
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
|
||||
|
|
|
@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
|
|||
|
||||
Example usage: `--logit-bias 29905-inf`
|
||||
|
||||
### Token healing
|
||||
|
||||
- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
|
||||
|
||||
Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
|
||||
|
||||
- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
|
||||
- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
|
||||
- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
|
||||
- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
|
||||
|
||||
Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
|
||||
|
||||
### RNG Seed
|
||||
|
||||
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
|
||||
|
|
|
@ -325,6 +325,16 @@ int main(int argc, char ** argv) {
|
|||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
}
|
||||
|
||||
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||
sparams.token_healing.enabled = false;
|
||||
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||
}
|
||||
llama_token_healing_output token_healing_out{};
|
||||
if (!params.interactive_first && sparams.token_healing.enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, embd_inp,
|
||||
sparams.token_healing.type, sparams.token_healing.n_rollback);
|
||||
}
|
||||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
if (add_bos) {
|
||||
|
@ -349,7 +359,7 @@ int main(int argc, char ** argv) {
|
|||
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
||||
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
||||
|
||||
original_prompt_len = original_inp.size();
|
||||
original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
|
||||
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
||||
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
||||
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
||||
|
@ -544,6 +554,7 @@ int main(int argc, char ** argv) {
|
|||
int n_consumed = 0;
|
||||
int n_session_consumed = 0;
|
||||
int n_past_guidance = 0;
|
||||
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
|
||||
|
||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||
|
@ -570,6 +581,7 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
|
@ -804,7 +816,15 @@ int main(int argc, char ** argv) {
|
|||
const std::string token_str = llama_token_to_piece(ctx, id, params.special);
|
||||
|
||||
// Console/Stream Output
|
||||
fprintf(stdout, "%s", token_str.c_str());
|
||||
// Suppress printing while generating token healing prefix
|
||||
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
|
||||
fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
|
||||
n_bytes_to_skip = 0;
|
||||
} else if (n_bytes_to_skip > 0) {
|
||||
n_bytes_to_skip -= token_str.size();
|
||||
} else {
|
||||
fprintf(stdout, "%s", token_str.c_str());
|
||||
}
|
||||
|
||||
// Record Displayed Tokens To Log
|
||||
// Note: Generated tokens are created one by one hence this check
|
||||
|
@ -896,6 +916,8 @@ int main(int argc, char ** argv) {
|
|||
assistant_ss << llama_token_to_piece(ctx, id, false);
|
||||
}
|
||||
|
||||
token_healing_out = {};
|
||||
|
||||
if (n_past > 0 && is_interacting) {
|
||||
LOG("waiting for user input\n");
|
||||
|
||||
|
@ -968,6 +990,16 @@ int main(int argc, char ** argv) {
|
|||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
||||
|
||||
if (sparams.token_healing.enabled) {
|
||||
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
||||
const int n_new_tokens = embd_inp.size() - original_size;
|
||||
const int max_to_remove = sparams.token_healing.n_rollback < 0
|
||||
? n_new_tokens
|
||||
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
|
||||
token_healing_out = llama_token_healing_rollback(ctx, embd_inp, sparams.token_healing.type, max_to_remove);
|
||||
n_bytes_to_skip = token_healing_out.prefix.size();
|
||||
}
|
||||
|
||||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||
const llama_token token = embd_inp[i];
|
||||
output_tokens.push_back(token);
|
||||
|
@ -977,7 +1009,7 @@ int main(int argc, char ** argv) {
|
|||
// reset assistant message
|
||||
assistant_ss.str("");
|
||||
|
||||
n_remain -= line_inp.size();
|
||||
n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
|
||||
LOG("n_remain: %d\n", n_remain);
|
||||
} else {
|
||||
LOG("empty line, passing control back\n");
|
||||
|
@ -989,6 +1021,10 @@ int main(int argc, char ** argv) {
|
|||
if (n_past > 0) {
|
||||
if (is_interacting) {
|
||||
llama_sampling_reset(ctx_sampling);
|
||||
if (token_healing_out.n_tokens_removed > 0) {
|
||||
// Set new prefix after an interaction
|
||||
llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
|
||||
}
|
||||
}
|
||||
is_interacting = false;
|
||||
}
|
||||
|
|
|
@ -482,6 +482,8 @@ node index.js
|
|||
|
||||
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
|
||||
|
||||
`token_healing`: Set the token healing strategy. Default: `0`, which is disabled. Possible values: `1` to replace one token, `d1` to replace the longest suffix with a single token, `d` to replace the longest suffix, `rN` to roll back N tokens (e.g. `r3`). See [here](../main/README.md#token-healing) for more details.
|
||||
|
||||
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
|
||||
|
||||
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
|
||||
|
|
|
@ -186,6 +186,7 @@ struct server_slot {
|
|||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
size_t n_sent_token_probs = 0;
|
||||
size_t n_th_prefix = 0; // size of remaining token healing prefix
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
|
@ -207,6 +208,7 @@ struct server_slot {
|
|||
infill = false;
|
||||
ga_i = 0;
|
||||
n_past_se = 0;
|
||||
n_th_prefix = 0;
|
||||
|
||||
generated_token_probs.clear();
|
||||
}
|
||||
|
@ -1096,6 +1098,25 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto & token_healing_str = data.find("token_healing");
|
||||
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
||||
const auto value = token_healing_str->get<std::string>();
|
||||
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
|
||||
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
LOG_VERBOSE("token healing", {
|
||||
{"id_slot", slot.id},
|
||||
{"enabled", slot.sparams.token_healing.enabled},
|
||||
{"type", slot.sparams.token_healing.type},
|
||||
{"n_rollback", slot.sparams.token_healing.n_rollback}
|
||||
});
|
||||
} else {
|
||||
slot.sparams.token_healing = default_sparams.token_healing;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
|
@ -1182,14 +1203,26 @@ struct server_context {
|
|||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||
slot.sampled = result.tok;
|
||||
|
||||
// search stop word and delete it
|
||||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
||||
// Suppress generating the token healing prefix to not repeat the input prompt's suffix
|
||||
bool is_token_healing = false;
|
||||
if (slot.n_th_prefix > 0) {
|
||||
if (slot.n_th_prefix < token_str.size()) {
|
||||
slot.generated_text += token_str.substr(slot.n_th_prefix);
|
||||
slot.n_th_prefix = 0;
|
||||
is_token_healing = false; // to send partial token text when streaming
|
||||
} else {
|
||||
slot.n_th_prefix -= token_str.size();
|
||||
is_token_healing = true;
|
||||
}
|
||||
} else {
|
||||
slot.generated_text += token_str;
|
||||
}
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
||||
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||
|
@ -1217,7 +1250,7 @@ struct server_context {
|
|||
break;
|
||||
}
|
||||
|
||||
if (!incomplete) {
|
||||
if (!incomplete && !is_token_healing) {
|
||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||
|
||||
const std::string str_test = slot.generated_text.substr(pos);
|
||||
|
@ -1249,7 +1282,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
if (incomplete) {
|
||||
if (incomplete || is_token_healing) {
|
||||
slot.has_next_token = true;
|
||||
}
|
||||
|
||||
|
@ -1354,7 +1387,8 @@ struct server_context {
|
|||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
{"samplers", samplers_sequence},
|
||||
{"token_healing_enabled", slot.sparams.token_healing.enabled}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -2038,6 +2072,8 @@ struct server_context {
|
|||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
llama_token_healing_output token_healing_out{};
|
||||
|
||||
if (slot.infill) {
|
||||
const bool add_bos = llama_add_bos_token(model);
|
||||
bool suff_rm_leading_spc = true;
|
||||
|
@ -2057,6 +2093,12 @@ struct server_context {
|
|||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
||||
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
// For FIM roll back only the prefix part (i.e. cursor location)
|
||||
token_healing_out = llama_token_healing_rollback(ctx, prefix_tokens,
|
||||
slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
|
||||
}
|
||||
|
||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
||||
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
||||
if (add_bos) {
|
||||
|
@ -2072,6 +2114,11 @@ struct server_context {
|
|||
prompt_tokens = embd_inp;
|
||||
} else {
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, prompt_tokens,
|
||||
slot.sparams.token_healing.type, slot.sparams.token_healing.n_rollback);
|
||||
}
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
|
@ -2086,6 +2133,16 @@ struct server_context {
|
|||
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
||||
});
|
||||
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
slot.n_th_prefix = token_healing_out.prefix.size();
|
||||
LOG_VERBOSE("token healing prompt", {
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task},
|
||||
{"removed_suffix", token_healing_out.prefix},
|
||||
{"n_tokens_removed", token_healing_out.n_tokens_removed}
|
||||
});
|
||||
}
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (prompt_tokens.empty()) {
|
||||
LOG_INFO("empty prompt - releasing slot", {
|
||||
|
@ -2151,6 +2208,9 @@ struct server_context {
|
|||
}
|
||||
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
||||
}
|
||||
|
||||
if (!slot.params.cache_prompt) {
|
||||
slot.n_past_se = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue