From d5eea137977d57741eff9913e3304c5f10fcef90 Mon Sep 17 00:00:00 2001
From: mare5x
Date: Wed, 26 Jun 2024 17:12:57 +0200
Subject: [PATCH] server : add token healing support

---
 examples/server/README.md  |  2 +
 examples/server/server.cpp | 77 ++++++++++++++++++++++++++++++++++----
 2 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index e17595fe8..f2cea4741 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -436,6 +436,8 @@ node index.js
 
     `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
 
+    `token_healing`: Set the token healing strategy. Accepted values: `0` (disabled), `1` (roll back the last token), `r{N}` (roll back `N` tokens), `d1` and `d` (dynamic modes). Default: `0`, which is disabled.
+
     `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
 
     `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 360f571e4..0d556ac24 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -185,6 +185,7 @@ struct server_slot {
     // stats
     size_t n_sent_text = 0; // number of sent text character
     size_t n_sent_token_probs = 0;
+    size_t n_th_prefix = 0; // size of remaining token healing prefix
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -206,6 +207,7 @@
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        n_th_prefix = 0;
 
         generated_token_probs.clear();
     }
@@ -1094,6 +1096,36 @@ struct server_context {
             }
         }
 
+        {
+            const auto & token_healing_str = data.find("token_healing");
+            auto & th_enabled = slot.sparams.token_healing_enabled;
+            th_enabled = default_sparams.token_healing_enabled;
+            if (token_healing_str != data.end() && token_healing_str->is_string()) {
+                const auto value = token_healing_str->get<std::string>();
+                auto & th_type = slot.sparams.token_healing_type;
+                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
+                th_enabled = true;
+                /**/ if (value == "0" ) { th_enabled = false; }
+                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+                else if (value[0] == 'r' ) {
+                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
+                    th_n_rollback = std::stoi(value.substr(1));
+                    if (th_n_rollback <= 0) {
+                        th_enabled = false;
+                    }
+                } else { th_enabled = false; }
+
+                LOG_VERBOSE("token healing", {
+                    {"id_slot", slot.id},
+                    {"enabled", th_enabled},
+                    {"type", th_type},
+                    {"n_rollback", th_n_rollback}
+                });
+            }
+        }
+
         {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
@@ -1189,14 +1221,26 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
-        // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
-
-        // search stop word and delete it
-        slot.generated_text += token_str;
         slot.has_next_token = true;
 
+        // Suppress generating the token healing prefix to not repeat the input prompt's suffix
+        bool is_token_healing = false;
+        if (slot.n_th_prefix > 0) {
+            if (slot.n_th_prefix < token_str.size()) {
+                slot.generated_text += token_str.substr(slot.n_th_prefix);
+                slot.n_th_prefix = 0;
+                is_token_healing = false; // to send partial token text when streaming
+            } else {
+                slot.n_th_prefix -= token_str.size();
+                is_token_healing = true;
+            }
+        } else {
+            slot.generated_text += token_str;
+        }
+
+        // remember which tokens were sampled - used for repetition penalties during sampling
         if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
             // we can change penalty_prompt_tokens because it is always created from scratch each request
             slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
@@ -1224,7 +1268,7 @@ struct server_context {
             break;
         }
 
-        if (!incomplete) {
+        if (!incomplete && !is_token_healing) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1256,7 +1300,7 @@ struct server_context {
             }
         }
 
-        if (incomplete) {
+        if (incomplete || is_token_healing) {
             slot.has_next_token = true;
         }
 
@@ -1361,7 +1405,8 @@ struct server_context {
             {"n_probs", slot.sparams.n_probs},
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
-            {"samplers", samplers_sequence}
+            {"samplers", samplers_sequence},
+            {"token_healing_enabled", slot.sparams.token_healing_enabled}
         };
     }
 
@@ -2106,6 +2151,21 @@ struct server_context {
                        continue;
                    }
 
+                    // Roll back prompt tokens if token healing
+                    llama_token_healing_output token_healing_out{};
+                    if (slot.sparams.token_healing_enabled) {
+                        token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                prompt_tokens, slot.sparams.token_healing_n_rollback);
+                        slot.n_th_prefix = token_healing_out.prefix.size();
+                        slot.n_prompt_tokens = prompt_tokens.size();
+                        LOG_VERBOSE("token healing prompt", {
+                            {"id_slot", slot.id},
+                            {"id_task", slot.id_task},
+                            {"removed_suffix", token_healing_out.prefix},
+                            {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                        });
+                    }
+
                     if (slot.embedding) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
@@ -2156,6 +2216,9 @@ struct server_context {
                    }
 
                    llama_sampling_reset(slot.ctx_sampling);
+                   if (slot.sparams.token_healing_enabled) {
+                       llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
+                   }
 
                    if (!slot.params.cache_prompt) {
                        slot.n_past_se = 0;
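
Usage sketch (illustrative, not part of the patch): token healing rolls back the trailing prompt tokens and constrains sampling so that the first generated tokens re-complete the removed suffix, which helps when a prompt ends in the middle of a token (for example after a partial word). Assuming a server built with this change and listening on the default port 8080, the option is passed in the completion request body alongside the other sampling options, using any of the values parsed above (`"0"`, `"1"`, `"d1"`, `"d"`, `"r{N}"`):

    curl --request POST \
        --url http://localhost:8080/completion \
        --header "Content-Type: application/json" \
        --data '{"prompt": "Building a website can be done in 10 sim", "n_predict": 32, "token_healing": "d"}'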