From 00ae55b3883450f7cce89fd6e24109cf130fc269 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 22 Oct 2023 20:09:01 +0300
Subject: [PATCH] server : hide ctx_sampling->prev behind API (#3696)

---
 examples/server/server.cpp | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 686a8c7c3..bd16100f3 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1559,7 +1559,8 @@ struct llama_server_context
 
                 if (!slot.params.cache_prompt)
                 {
-                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
+                    llama_sampling_reset(slot.ctx_sampling);
+
                     slot.n_past = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
@@ -1570,16 +1571,17 @@ struct llama_server_context
                         slot.params.n_keep = slot.num_prompt_tokens;
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-                    //if input prompt is too big, truncate like normal
+
+                    // if input prompt is too big, truncate it
                     if (slot.num_prompt_tokens >= slot.n_ctx)
                     {
-                        // applied bug of #3661
                         const int n_left = slot.n_ctx - slot.params.n_keep;
                         const int n_block_size = n_left / 2;
                         const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
                         std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                        // Use half the left-over space in the context for the prompt
                         new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
                         LOG_VERBOSE("input truncated", {
                             {"n_ctx", slot.n_ctx},
                             {"n_keep", slot.params.n_keep},
@@ -1588,14 +1590,20 @@ struct llama_server_context
                         });
                         slot.truncated = true;
                         prompt_tokens = new_tokens;
+
                         slot.num_prompt_tokens = prompt_tokens.size();
                         GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
                     }
-                    const size_t ps = slot.num_prompt_tokens;
-                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
-                    std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
+
+                    // push the prompt into the sampling context (do not apply grammar)
+                    for (auto &token : prompt_tokens)
+                    {
+                        llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+                    }
+
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+
                     LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                 }
 
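
For reference, a minimal sketch of the pattern this patch adopts: rather than zero-filling ctx_sampling->prev by hand, the server resets the sampling state and then replays the prompt through the sampling API so the repetition-penalty history is populated correctly. Only llama_sampling_reset and the four-argument llama_sampling_accept call are taken from the patch itself; the header path, the helper name prime_sampling_state, and the surrounding setup are illustrative assumptions, not code from the repository.

    #include <vector>

    #include "common/sampling.h" // assumed location of the llama_sampling_* helpers

    // Hypothetical helper: replay a tokenized prompt into a sampling context
    // so that repetition penalties later see the correct token history.
    static void prime_sampling_state(llama_sampling_context * ctx_sampling,
                                     llama_context * ctx,
                                     const std::vector<llama_token> & prompt_tokens) {
        // reset the sampling state through the API instead of writing
        // into ctx_sampling->prev directly
        llama_sampling_reset(ctx_sampling);

        // accept each prompt token; the final argument disables grammar
        // application, matching the patch ("do not apply grammar")
        for (const llama_token token : prompt_tokens) {
            llama_sampling_accept(ctx_sampling, ctx, token, false);
        }
    }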