server : add token healing support
This commit is contained in:
parent
fc8773d309
commit
d5eea13797
2 changed files with 72 additions and 7 deletions
|
@ -436,6 +436,8 @@ node index.js
|
|||
|
||||
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
|
||||
|
||||
`token_healing`: Set token healing strategy. Default: `0`, which is disabled.
|
||||
|
||||
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
|
||||
|
||||
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
|
||||
|
|
|
@ -185,6 +185,7 @@ struct server_slot {
|
|||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
size_t n_sent_token_probs = 0;
|
||||
size_t n_th_prefix = 0; // size of remaining token healing prefix
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
|
@ -206,6 +207,7 @@ struct server_slot {
|
|||
infill = false;
|
||||
ga_i = 0;
|
||||
n_past_se = 0;
|
||||
n_th_prefix = 0;
|
||||
|
||||
generated_token_probs.clear();
|
||||
}
|
||||
|
@ -1094,6 +1096,36 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto & token_healing_str = data.find("token_healing");
|
||||
auto & th_enabled = slot.sparams.token_healing_enabled;
|
||||
th_enabled = default_sparams.token_healing_enabled;
|
||||
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
||||
const auto value = token_healing_str->get<std::string>();
|
||||
auto & th_type = slot.sparams.token_healing_type;
|
||||
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
|
||||
th_enabled = true;
|
||||
/**/ if (value == "0" ) { th_enabled = false; }
|
||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value[0] == 'r' ) {
|
||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_n_rollback = std::stoi(value.substr(1));
|
||||
if (th_n_rollback <= 0) {
|
||||
th_enabled = false;
|
||||
}
|
||||
} else { th_enabled = false; }
|
||||
|
||||
LOG_VERBOSE("token healing", {
|
||||
{"id_slot", slot.id},
|
||||
{"enabled", th_enabled},
|
||||
{"type", th_type},
|
||||
{"n_rollback", th_n_rollback}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
|
@ -1189,14 +1221,26 @@ struct server_context {
|
|||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||
slot.sampled = result.tok;
|
||||
|
||||
// search stop word and delete it
|
||||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
||||
// Suppress generating the token healing prefix to not repeat the input prompt's suffix
|
||||
bool is_token_healing = false;
|
||||
if (slot.n_th_prefix > 0) {
|
||||
if (slot.n_th_prefix < token_str.size()) {
|
||||
slot.generated_text += token_str.substr(slot.n_th_prefix);
|
||||
slot.n_th_prefix = 0;
|
||||
is_token_healing = false; // to send partial token text when streaming
|
||||
} else {
|
||||
slot.n_th_prefix -= token_str.size();
|
||||
is_token_healing = true;
|
||||
}
|
||||
} else {
|
||||
slot.generated_text += token_str;
|
||||
}
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
||||
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||
|
@ -1224,7 +1268,7 @@ struct server_context {
|
|||
break;
|
||||
}
|
||||
|
||||
if (!incomplete) {
|
||||
if (!incomplete && !is_token_healing) {
|
||||
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
||||
|
||||
const std::string str_test = slot.generated_text.substr(pos);
|
||||
|
@ -1256,7 +1300,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
if (incomplete) {
|
||||
if (incomplete || is_token_healing) {
|
||||
slot.has_next_token = true;
|
||||
}
|
||||
|
||||
|
@ -1361,7 +1405,8 @@ struct server_context {
|
|||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
{"samplers", samplers_sequence},
|
||||
{"token_healing_enabled", slot.sparams.token_healing_enabled}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -2106,6 +2151,21 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Roll back prompt tokens if token healing
|
||||
llama_token_healing_output token_healing_out{};
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
|
||||
prompt_tokens, slot.sparams.token_healing_n_rollback);
|
||||
slot.n_th_prefix = token_healing_out.prefix.size();
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
LOG_VERBOSE("token healing prompt", {
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task},
|
||||
{"removed_suffix", token_healing_out.prefix},
|
||||
{"n_tokens_removed", token_healing_out.n_tokens_removed}
|
||||
});
|
||||
}
|
||||
|
||||
if (slot.embedding) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
|
@ -2156,6 +2216,9 @@ struct server_context {
|
|||
}
|
||||
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
||||
}
|
||||
|
||||
if (!slot.params.cache_prompt) {
|
||||
slot.n_past_se = 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue