From 3ba5c55bc47c97136fab135f07526d8bbd6be667 Mon Sep 17 00:00:00 2001
From: mare5x
Date: Sun, 30 Jun 2024 22:30:15 +0200
Subject: [PATCH] server : token healing for infilling/FIM

---
 examples/server/server.cpp | 38 +++++++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0d556ac24..01aac8ed5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2088,6 +2088,8 @@ struct server_context {
                         slot.t_start_process_prompt = ggml_time_us();
                         slot.t_start_generation = 0;
 
+                        llama_token_healing_output token_healing_out{};
+
                         if (slot.infill) {
                             const bool add_bos = llama_should_add_bos_token(model);
                             bool suff_rm_leading_spc = true;
@@ -2107,6 +2109,12 @@ struct server_context {
                             prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                             suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
 
+                            if (slot.sparams.token_healing_enabled) {
+                                // For FIM roll back only the prefix part (i.e. cursor location)
+                                token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                                 prefix_tokens, slot.sparams.token_healing_n_rollback);
+                            }
+
                             auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
                             auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
                             if (add_bos) {
@@ -2122,6 +2130,11 @@ struct server_context {
                             prompt_tokens = embd_inp;
                         } else {
                             prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+
+                            if (slot.sparams.token_healing_enabled) {
+                                token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
+                                                                                 prompt_tokens, slot.sparams.token_healing_n_rollback);
+                            }
                         }
 
                         slot.n_past = 0;
@@ -2136,6 +2149,16 @@ struct server_context {
                             {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                         });
 
+                        if (slot.sparams.token_healing_enabled) {
+                            slot.n_th_prefix = token_healing_out.prefix.size();
+                            LOG_VERBOSE("token healing prompt", {
+                                {"id_slot", slot.id},
+                                {"id_task", slot.id_task},
+                                {"removed_suffix", token_healing_out.prefix},
+                                {"n_tokens_removed", token_healing_out.n_tokens_removed}
+                            });
+                        }
+
                         // empty prompt passed -> release the slot and send empty response
                         if (prompt_tokens.empty()) {
                             LOG_INFO("empty prompt - releasing slot", {
@@ -2151,21 +2174,6 @@ struct server_context {
                             continue;
                         }
 
-                        // Roll back prompt tokens if token healing
-                        llama_token_healing_output token_healing_out{};
-                        if (slot.sparams.token_healing_enabled) {
-                            token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
-                                                                             prompt_tokens, slot.sparams.token_healing_n_rollback);
-                            slot.n_th_prefix = token_healing_out.prefix.size();
-                            slot.n_prompt_tokens = prompt_tokens.size();
-                            LOG_VERBOSE("token healing prompt", {
-                                {"id_slot", slot.id},
-                                {"id_task", slot.id_task},
-                                {"removed_suffix", token_healing_out.prefix},
-                                {"n_tokens_removed", token_healing_out.n_tokens_removed}
-                            });
-                        }
-
                         if (slot.embedding) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_ubatch) {
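
Note (reviewer sketch, not part of the patch): the rollback done by llama_token_healing_rollback boils down to popping the last few prompt tokens and remembering their concatenated text, so that generation can be constrained to reproduce that text as the prefix of a freshly tokenized ("healed") continuation. The following is a minimal, self-contained C++ sketch of that idea only; it omits the type-dependent rollback checks from the patch, and piece_of is a hypothetical helper returning a token's text (in llama.cpp that role is played by llama_token_to_piece).

    #include <string>
    #include <vector>

    struct token_healing_result {
        std::string prefix;        // text of the removed tokens (to be regenerated)
        int n_tokens_removed = 0;  // how many tokens were dropped from the prompt
    };

    // Remove up to max_rollback trailing tokens and collect their text.
    template <typename PieceFn>
    token_healing_result rollback_last_tokens(std::vector<int> & tokens, int max_rollback, PieceFn piece_of) {
        token_healing_result out;
        while (out.n_tokens_removed < max_rollback && !tokens.empty()) {
            out.prefix = piece_of(tokens.back()) + out.prefix;  // prepend, since we walk backwards
            tokens.pop_back();
            out.n_tokens_removed++;
        }
        return out;
    }

As the patch's comment notes, for FIM requests only prefix_tokens (the text before the cursor) is rolled back, so healing happens at the cursor position rather than at the end of the suffix.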