server : add token healing support

mare5x 2024-06-26 17:12:57 +02:00
parent fc8773d309
commit d5eea13797
2 changed files with 72 additions and 7 deletions

View file

@@ -436,6 +436,8 @@ node index.js
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` for a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
`token_healing`: Set the token healing strategy. Default: `0`, which is disabled. Possible values: `0` (disabled), `1` (roll back the last token), `r<N>` (roll back the last `N` tokens, e.g. `r3`), `d1` (dynamic, once), `d` (dynamic, multi). An example request body is sketched below.
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
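
For illustration, here is a hedged sketch of a `/completion` request body that enables the new parameter. It uses nlohmann::json (which the server already links against); the prompt text and the choice of `"r2"` are example values, not taken from this commit.

```cpp
// Sketch only: build an example JSON body for POST /completion with token
// healing enabled. "r2" requests rolling back the last 2 prompt tokens; the
// other accepted values are "0", "1", "d1" and "d" (see the parsing code below).
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json body = {
        {"prompt",        "int main(int argc, char **ar"}, // deliberately cut mid-word
        {"n_predict",     64},
        {"token_healing", "r2"}
    };
    std::cout << body.dump(2) << std::endl; // send this as the request body
    return 0;
}
```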

View file

@@ -185,6 +185,7 @@ struct server_slot {
// stats
size_t n_sent_text = 0; // number of sent text characters
size_t n_sent_token_probs = 0;
size_t n_th_prefix = 0; // size of remaining token healing prefix
int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -206,6 +207,7 @@ struct server_slot {
infill = false;
ga_i = 0;
n_past_se = 0;
n_th_prefix = 0;
generated_token_probs.clear();
}
@@ -1094,6 +1096,36 @@ struct server_context {
}
}
{
const auto & token_healing_str = data.find("token_healing");
auto & th_enabled = slot.sparams.token_healing_enabled;
th_enabled = default_sparams.token_healing_enabled;
if (token_healing_str != data.end() && token_healing_str->is_string()) {
const auto value = token_healing_str->get<std::string>();
auto & th_type = slot.sparams.token_healing_type;
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
th_enabled = true;
/**/ if (value == "0" ) { th_enabled = false; }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
th_enabled = false;
}
} else { th_enabled = false; }
LOG_VERBOSE("token healing", {
{"id_slot", slot.id},
{"enabled", th_enabled},
{"type", th_type},
{"n_rollback", th_n_rollback}
});
}
}
{
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
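
For reference, the value format accepted by the parsing block above can be summarized as a standalone sketch; the enum, struct, and function names below are hypothetical stand-ins for illustration, not the llama.cpp types used in this commit.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for illustration only.
enum class th_type { rollback_last, rollback_multi, dynamic_once, dynamic_multi };
struct th_params { bool enabled = false; th_type type = th_type::rollback_last; int n_rollback = 0; };

// Mirrors the accepted strings: "0" off, "1" last token, "r<N>" N tokens, "d1"/"d" dynamic.
static th_params parse_token_healing(const std::string & value) {
    th_params p;
    p.enabled = true;
    if      (value == "0")  { p.enabled = false; }
    else if (value == "1")  { p.type = th_type::rollback_last; }
    else if (value == "d1") { p.type = th_type::dynamic_once; }
    else if (value == "d")  { p.type = th_type::dynamic_multi; }
    else if (!value.empty() && value[0] == 'r') {
        p.type = th_type::rollback_multi;
        p.n_rollback = std::stoi(value.substr(1)); // throws on "r" with no number
        if (p.n_rollback <= 0) { p.enabled = false; }
    } else { p.enabled = false; }
    return p;
}

int main() {
    assert(parse_token_healing("r3").n_rollback == 3);
    assert(!parse_token_healing("0").enabled);
    assert(parse_token_healing("d").type == th_type::dynamic_multi);
    return 0;
}
```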
@@ -1189,14 +1221,26 @@ struct server_context {
}
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
slot.sampled = result.tok;
// search stop word and delete it
slot.generated_text += token_str;
slot.has_next_token = true;
// Suppress the token healing prefix in the output so the input prompt's suffix is not repeated
bool is_token_healing = false;
if (slot.n_th_prefix > 0) {
if (slot.n_th_prefix < token_str.size()) {
slot.generated_text += token_str.substr(slot.n_th_prefix);
slot.n_th_prefix = 0;
is_token_healing = false; // to send partial token text when streaming
} else {
slot.n_th_prefix -= token_str.size();
is_token_healing = true;
}
} else {
slot.generated_text += token_str;
}
// remember which tokens were sampled - used for repetition penalties during sampling
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
// we can change penalty_prompt_tokens because it is always created from scratch each request
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
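
The suppression arithmetic in the new branch above can be exercised in isolation. The following toy sketch (not llama.cpp code; all values are made up) shows a sampled piece either being fully swallowed by the remaining healing prefix or emitted with that prefix stripped:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

int main() {
    std::string generated;                // mirrors slot.generated_text
    std::size_t n_th_prefix = 4;          // 4 bytes of healing prefix left, e.g. ":// "
    const std::string piece = ":// ww";   // hypothetical sampled piece (6 bytes)

    if (n_th_prefix > 0) {
        if (n_th_prefix < piece.size()) {
            generated += piece.substr(n_th_prefix);  // emit only the part past the prefix
            n_th_prefix = 0;
        } else {
            n_th_prefix -= piece.size();             // piece fully covered by the prefix, emit nothing
        }
    } else {
        generated += piece;                          // normal path once the prefix is consumed
    }

    assert(generated == "ww");
    assert(n_th_prefix == 0);
    return 0;
}
```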
@@ -1224,7 +1268,7 @@ struct server_context {
break;
}
if (!incomplete) {
if (!incomplete && !is_token_healing) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
@@ -1256,7 +1300,7 @@
}
}
if (incomplete) {
if (incomplete || is_token_healing) {
slot.has_next_token = true;
}
@@ -1361,7 +1405,8 @@ struct server_context {
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence}
{"samplers", samplers_sequence},
{"token_healing_enabled", slot.sparams.token_healing_enabled}
};
}
@@ -2106,6 +2151,21 @@ struct server_context {
continue;
}
// Roll back prompt tokens if token healing is enabled
llama_token_healing_output token_healing_out{};
if (slot.sparams.token_healing_enabled) {
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
prompt_tokens, slot.sparams.token_healing_n_rollback);
slot.n_th_prefix = token_healing_out.prefix.size();
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("token healing prompt", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"removed_suffix", token_healing_out.prefix},
{"n_tokens_removed", token_healing_out.n_tokens_removed}
});
}
if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) {
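
To make the rollback step concrete: judging from how it is used above, `llama_token_healing_rollback` (defined outside the two files shown in this diff) removes tokens from the end of the prompt and reports the removed text as the healing prefix. A toy stand-in with hypothetical names might look like this:

```cpp
#include <string>
#include <vector>

// Hypothetical illustration: drop the last n tokens of the prompt and remember
// their concatenated text; that text becomes the healing prefix (slot.n_th_prefix).
struct toy_th_output {
    std::string prefix;
    int n_tokens_removed = 0;
};

static toy_th_output toy_rollback(std::vector<std::string> & pieces, int n_rollback) {
    toy_th_output out;
    while (n_rollback-- > 0 && pieces.size() > 1) { // keep at least one prompt token
        out.prefix = pieces.back() + out.prefix;
        pieces.pop_back();
        out.n_tokens_removed++;
    }
    return out;
}

int main() {
    std::vector<std::string> prompt = { "int", " main", "(", "int", " ar" };
    const toy_th_output res = toy_rollback(prompt, 1);
    // res.prefix == " ar": generation is now constrained to reproduce " ar" first,
    // so the model can heal it into e.g. " argc" instead of being stuck mid-word.
    return (res.prefix == " ar") ? 0 : 1;
}
```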
@@ -2156,6 +2216,9 @@
}
llama_sampling_reset(slot.ctx_sampling);
if (slot.sparams.token_healing_enabled) {
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
}
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
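
Finally, `llama_token_healing_set_prefix` presumably constrains sampling so that generation reproduces the removed text before continuing freely. A toy sketch of that kind of prefix constraint (not the llama.cpp sampler API) is:

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Toy constraint check: a candidate piece is compatible with the healing prefix
// if one is a prefix of the other (a shorter piece leaves the rest of the prefix
// to be healed by subsequent tokens).
static bool is_compatible(const std::string & piece, const std::string & prefix) {
    const std::size_t n = std::min(piece.size(), prefix.size());
    return piece.compare(0, n, prefix, 0, n) == 0;
}

int main() {
    const std::string prefix = " ar";  // text removed by the rollback step
    const std::vector<std::string> candidates = { " argc", " argv", " a", " the", "gc" };
    int kept = 0;
    for (const auto & c : candidates) {
        if (is_compatible(c, prefix)) {
            kept++;  // " argc", " argv" and " a" pass; " the" and "gc" are masked out
        }
    }
    return (kept == 3) ? 0 : 1;
}
```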