server : add token healing support

parent fc8773d309
commit d5eea13797

2 changed files with 72 additions and 7 deletions
@@ -436,6 +436,8 @@ node index.js
 
 `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
 
+`token_healing`: Set the token healing strategy: `0` (disabled), `1` (roll back the last token), `d1` (dynamic, at most one step), `d` (dynamic, multiple steps), or `rN` (roll back the last N tokens, e.g. `r3`). Default: `0`, which is disabled.
+
 `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
 
 `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
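As a usage sketch (not part of the diff): a `/completion` request would carry the new option next to the other sampling parameters. The prompt, `n_predict` and `temperature` values below are placeholders, and the body is built with nlohmann::json only because that makes a self-contained example:

```cpp
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Hypothetical request body for POST /completion. "d" selects dynamic multi-step healing;
    // "0" (the default) disables it, "1" rolls back only the last prompt token, "r3" rolls back 3.
    json body = {
        {"prompt",        "def fib(n"}, // placeholder prompt that ends mid-identifier
        {"n_predict",     64},
        {"temperature",   0.8},
        {"token_healing", "d"},
    };
    printf("%s\n", body.dump(2).c_str());
    return 0;
}
```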
@@ -185,6 +185,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -206,6 +207,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        n_th_prefix = 0;
 
         generated_token_probs.clear();
     }
@@ -1094,6 +1096,36 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /**/ if (value == "0" ) { th_enabled = false; }
+                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else if (value[0] == 'r' ) {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot", slot.id},
+                    {"enabled", th_enabled},
+                    {"type", th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
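In summary, the request string selects one of four healing strategies. A standalone sketch of the same mapping (the enum values mirror `llama_token_healing_type` from this PR; the `token_healing_params` struct and `parse_token_healing` helper are hypothetical):

```cpp
#include <string>

// Mirrors llama_token_healing_type as used in the hunk above.
enum class token_healing_type { ROLLBACK_LAST, ROLLBACK_MULTI, DYNAMIC_ONCE, DYNAMIC_MULTI };

struct token_healing_params {
    bool               enabled    = false;
    token_healing_type type       = token_healing_type::DYNAMIC_MULTI;
    int                n_rollback = -1; // only meaningful for the "rN" form
};

// Hypothetical helper: map the request value to a strategy, as the server code above does.
static token_healing_params parse_token_healing(const std::string & value) {
    token_healing_params p;
    p.enabled = true;
    if      (value == "0")  { p.enabled = false; }                          // disabled
    else if (value == "1")  { p.type = token_healing_type::ROLLBACK_LAST; } // roll back the last token
    else if (value == "d1") { p.type = token_healing_type::DYNAMIC_ONCE; }  // dynamic, at most one step
    else if (value == "d")  { p.type = token_healing_type::DYNAMIC_MULTI; } // dynamic, multiple steps
    else if (!value.empty() && value[0] == 'r') {                           // "rN": roll back N tokens
        p.type = token_healing_type::ROLLBACK_MULTI;
        p.n_rollback = std::stoi(value.substr(1));                          // throws on malformed input
        if (p.n_rollback <= 0) { p.enabled = false; }
    } else {
        p.enabled = false;
    }
    return p;
}
```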
@@ -1189,14 +1221,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
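The `n_th_prefix` bookkeeping above is easiest to follow on a concrete stream of token pieces. A self-contained sketch of the same consume-then-emit logic (the `consume_healing_prefix` helper is hypothetical, not the server code):

```cpp
#include <cstdio>
#include <string>

// Given the remaining healing prefix length and the next decoded token piece, return the text
// that should be appended to the visible output. The healed prefix was already present at the
// end of the prompt, so it is suppressed rather than sent to the client again.
static std::string consume_healing_prefix(size_t & n_th_prefix, const std::string & token_str) {
    if (n_th_prefix == 0) {
        return token_str;                                     // healing finished: pass pieces through
    }
    if (n_th_prefix < token_str.size()) {
        std::string visible = token_str.substr(n_th_prefix);  // token spills past the prefix
        n_th_prefix = 0;
        return visible;
    }
    n_th_prefix -= token_str.size();                          // piece lies entirely inside the prefix
    return "";
}

int main() {
    // Suppose the prompt ended with "ind" and healing rolled that token back,
    // so generation must first re-produce the 3 suppressed bytes "ind".
    size_t n_th_prefix = 3;
    const char * pieces[] = { "indent", "ation", " matters" };
    std::string out;
    for (const char * p : pieces) {
        out += consume_healing_prefix(n_th_prefix, p);
    }
    printf("%s\n", out.c_str()); // prints "entation matters"
    return 0;
}
```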
@@ -1224,7 +1268,7 @@ struct server_context {
             break;
         }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1256,7 +1300,7 @@ struct server_context {
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }
 
@@ -1361,7 +1405,8 @@ struct server_context {
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
-            {"samplers", samplers_sequence}
+            {"samplers", samplers_sequence},
+            {"token_healing_enabled", slot.sparams.token_healing_enabled}
         };
     }
 
@@ -2106,6 +2151,21 @@ struct server_context {
                         continue;
                     }
 
+                    // Roll back prompt tokens if token healing
+                    llama_token_healing_output token_healing_out{};
+                    if (slot.sparams.token_healing_enabled) {
+                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                         prompt_tokens, slot.sparams.token_healing_n_rollback);
+                        slot.n_th_prefix = token_healing_out.prefix.size();
+                        slot.n_prompt_tokens = prompt_tokens.size();
+                        LOG_VERBOSE("token healing prompt", {
+                            {"id_slot", slot.id},
+                            {"id_task", slot.id_task},
+                            {"removed_suffix", token_healing_out.prefix},
+                            {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                        });
+                    }
+
                     if (slot.embedding) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
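`llama_token_healing_output` and `llama_token_healing_rollback` come from the sampling-side changes that are not part of this diff. For the simple `1` (ROLLBACK_LAST) strategy, the rollback conceptually amounts to the sketch below (the `heal_last_token` helper is hypothetical; `llama_token_to_piece` is the same common helper used in the `process_token` hunk):

```cpp
#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

// Drop the last prompt token and remember its text. That text becomes the "healing prefix":
// the prompt is evaluated without it, and generation is constrained to re-produce it, so the
// model is free to pick a token that crosses the original prompt boundary.
static std::string heal_last_token(llama_context * ctx, std::vector<llama_token> & prompt_tokens) {
    if (prompt_tokens.size() < 2) {
        return ""; // nothing to heal (keep at least one token, e.g. BOS)
    }
    const llama_token removed = prompt_tokens.back();
    prompt_tokens.pop_back();
    return llama_token_to_piece(ctx, removed, false);
}
```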
@@ -2156,6 +2216,9 @@ struct server_context {
                     }
 
                     llama_sampling_reset(slot.ctx_sampling);
+                    if (slot.sparams.token_healing_enabled) {
+                        llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                    }
 
                     if (!slot.params.cache_prompt) {
                         slot.n_past_se = 0;
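`llama_token_healing_set_prefix` hands the removed text to the sampling context, which then only accepts tokens consistent with the remaining prefix until it has been fully re-generated. A conceptual sketch of that constraint (hypothetical helper, not this PR's sampler code):

```cpp
#include <algorithm>
#include <string>

// A candidate token piece is allowed while a healing prefix remains if the two agree:
// either the piece is an initial chunk of the prefix, or the piece starts with the whole
// prefix (the token "spills over" the prompt boundary and finishes the healing in one step).
static bool is_candidate_allowed(const std::string & piece, const std::string & remaining_prefix) {
    if (remaining_prefix.empty()) {
        return true; // healing finished: no constraint
    }
    const size_t n = std::min(piece.size(), remaining_prefix.size());
    return piece.compare(0, n, remaining_prefix, 0, n) == 0;
}
```

Tokens failing such a check would typically be masked out (logit set to negative infinity) before sampling, which is how prefix-constrained decoding is usually implemented.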