From e6a5a6c6ac8e5e5eee02c29dd747250c008486c4 Mon Sep 17 00:00:00 2001
From: Paulo de Castro
Date: Sun, 28 Apr 2024 00:34:07 -0300
Subject: [PATCH] server : avoid breaking KV cache when prompt >= n_ctx (#6855)

---
 examples/server/server.cpp                    | 92 +++++++++++++++++--
 examples/server/tests/features/steps/steps.py | 47 +++++++++-
 .../server/tests/features/truncation.feature  | 44 +++++++++
 examples/server/tests/tests.sh                |  2 +-
 4 files changed, 177 insertions(+), 8 deletions(-)
 create mode 100644 examples/server/tests/features/truncation.feature

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 47bea1591..8d6657e65 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -113,9 +113,11 @@ struct slot_params {
     bool stream = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
 
-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t n_predict = -1; // new tokens to predict
+    uint32_t seed = -1; // RNG seed
+    int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_truncate = 0; // number of tokens after n_keep that will be discarded when the prompt is bigger than the context
+    int32_t n_predict = -1; // new tokens to predict
 
     std::vector<std::string> antiprompt;
 
@@ -159,6 +161,7 @@ struct server_slot {
     bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
+    bool shifted = false;
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -194,6 +197,7 @@ struct server_slot {
     void reset() {
         n_prompt_tokens = 0;
         generated_text = "";
+        shifted = false;
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -916,6 +920,7 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
+        slot.params.n_truncate = json_value(data, "n_truncate", default_params.n_truncate);
         slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -1344,6 +1349,7 @@ struct server_context {
            {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
            {"n_keep", slot.params.n_keep},
            {"n_discard", slot.params.n_discard},
+           {"n_truncate", slot.params.n_truncate},
            {"ignore_eos", ignore_eos},
            {"stream", slot.params.stream},
            {"logit_bias", slot.sparams.logit_bias},
@@ -1431,6 +1437,7 @@ struct server_context {
            {"tokens_evaluated", slot.n_prompt_tokens},
            {"generation_settings", get_formated_generation(slot)},
            {"prompt", slot.prompt},
+           {"shifted", slot.shifted},
            {"truncated", slot.truncated},
            {"stopped_eos", slot.stopped_eos},
            {"stopped_word", slot.stopped_word},
@@ -1886,6 +1893,7 @@ struct server_context {
                            {"n_past", slot.n_past},
                            {"n_system_tokens", system_tokens.size()},
                            {"n_cache_tokens", slot.cache_tokens.size()},
+                           {"shifted", slot.shifted},
                            {"truncated", slot.truncated}
                        });
 
@@ -1959,7 +1967,7 @@ struct server_context {
                        slot.n_past -= n_discard;
 
-                       slot.truncated = true;
+                       slot.shifted = true;
                    }
                }
            }
 
@@ -1994,6 +2002,7 @@ struct server_context {
                        {"n_past", slot.n_past},
                        {"n_system_tokens", system_tokens.size()},
                        {"n_cache_tokens", slot.cache_tokens.size()},
+                       {"shifted", slot.shifted},
                        {"truncated", slot.truncated}
                    });
                }
@@ -2100,9 +2109,11 @@ struct server_context {
                // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
                    const int n_left = slot.n_ctx - slot.params.n_keep;
+                   const int n_truncate = slot.params.n_truncate;
 
                    const int n_block_size = n_left / 2;
-                   const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                   const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
+                   const int erased_tokens = n_truncate ? n_truncate : half_ctx;
 
                    std::vector<llama_token> new_tokens(
                            prompt_tokens.begin(),
@@ -2110,7 +2121,7 @@ struct server_context {
                    new_tokens.insert(
                            new_tokens.end(),
-                           prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                           prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
                            prompt_tokens.end());
 
                    prompt_tokens = std::move(new_tokens);
 
@@ -2118,6 +2129,17 @@ struct server_context {
                    slot.truncated = true;
                    slot.n_prompt_tokens = prompt_tokens.size();
 
+                   if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
+                       LOG_ERROR("prompt - n_truncate > n_ctx", {
+                           {"id_slot", slot.id},
+                           {"id_task", slot.id_task},
+                           {"n_ctx", slot.n_ctx},
+                           {"n_keep", slot.params.n_keep},
+                           {"n_left", n_left},
+                           {"n_truncate", n_truncate},
+                       });
+                   }
+
                    LOG_VERBOSE("input truncated", {
                        {"id_slot", slot.id},
                        {"id_task", slot.id_task},
@@ -2128,6 +2150,34 @@ struct server_context {
                        {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                    });
 
+                   if (n_truncate && slot.params.cache_prompt) {
+                       const int n_keep = slot.params.n_keep + add_bos_token;
+                       for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
+                           slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
+                       }
+
+                       const int new_cache_size = slot.cache_tokens.size() - n_truncate;
+                       if (new_cache_size >= 0) {
+                           slot.cache_tokens.resize(new_cache_size);
+
+                           LOG_VERBOSE("cache tokens shift", {
+                               {"id_slot", slot.id},
+                               {"id_task", slot.id_task},
+                               {"n_ctx", slot.n_ctx},
+                               {"n_keep", slot.params.n_keep},
+                               {"n_left", n_left},
+                               {"n_truncate", n_truncate},
+                               {"new_cache_size", new_cache_size},
+                               {"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
+                           });
+                       } else {
+                           LOG_ERROR("n_truncate needs to be used with cache_prompt", {
+                               {"id_slot", slot.id},
+                               {"id_task", slot.id_task},
+                           });
+                       }
+                   }
+
                    GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                }
 
@@ -2142,6 +2192,12 @@ struct server_context {
                // reuse any previously computed tokens that are common with the new prompt
                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
+               LOG_INFO("[cache_tokens, prompt_tokens]", {
+                   { "id_slot", slot.id },
+                   { "id_task", slot.id_task },
+                   { "common_part", slot.n_past}
+               });
+
                // push the prompt into the sampling context (do not apply grammar)
                for (int i = 0; i < slot.n_past; ++i) {
                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2172,6 +2228,30 @@ struct server_context {
                    }
                }
 
+               // shift KV cache if needed
+               const int n_keep = slot.params.n_keep + add_bos_token;
+               const int n_truncate = slot.params.n_truncate;
+               if (n_truncate && slot.params.cache_prompt && slot.truncated) {
+                   llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
+
+                   LOG_INFO("kv cache rm", {
+                       { "id_slot", slot.id },
+                       { "id_task", slot.id_task },
+                       { "p0", n_keep },
+                       { "p1", n_keep + n_truncate }
+                   });
+
+                   llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
+
+                   LOG_INFO("kv cache add", {
+                       { "id_slot", slot.id },
+                       { "id_task", slot.id_task },
+                       { "p0", n_keep + n_truncate },
+                       { "p1", slot.n_past },
+                       { "delta", -n_truncate }
+                   });
+               }
+
                // keep only the common part
                int p0 = (int) system_tokens.size() + slot.n_past;
                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index df0814cc9..98bbbe27d 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -51,6 +51,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.n_ga = None
     context.n_ga_w = None
     context.n_predict = None
+    context.n_keep = 0
+    context.n_truncate = 0
     context.n_prompts = 0
     context.n_server_predict = None
     context.slot_save_path = None
@@ -71,6 +73,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.temperature = None
 
+    context.stop_string = []
     context.tasks_result = []
     context.concurrent_tasks = []
     context.prompts = []
@@ -241,19 +244,35 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
 async def step_request_completion(context, api_error: Literal['raised'] | str):
+    await make_completion_request(context, api_error)
+
+
+@step('an ongoing completion request')
+@async_run_until_complete
+async def step_request_ongoing_completion(context):
+    await make_completion_request(context, '', True)
+
+
+async def make_completion_request(context, api_error: Literal['raised'] | str, ongoing=False):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
-    completion = await request_completion(context.prompts.pop(),
+    prompt = context.prompts[-1] if ongoing else context.prompts.pop()
+    completion = await request_completion(prompt,
                                           seeds[0] if seeds is not None else seeds,
                                           context.base_url,
                                           debug=context.debug,
                                           n_predict=context.n_predict,
+                                          n_keep=context.n_keep,
+                                          n_truncate=context.n_truncate,
                                           cache_prompt=context.cache_prompt,
+                                          stop_string=context.stop_string,
                                           id_slot=context.id_slot,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key,
                                           temperature=context.temperature)
     context.tasks_result.append(completion)
+    if ongoing and isinstance(completion, dict) and not expect_api_error:
+        context.prompts[-1] += completion['content']
     if context.debug:
         print(f"Completion response: {completion}")
     if expect_api_error:
@@ -336,6 +355,16 @@ def step_max_tokens(context, max_tokens):
     context.n_predict = max_tokens
 
 
+@step('{n_keep:d} tokens to keep')
+def step_keep_tokens(context, n_keep):
+    context.n_keep = n_keep
+
+
+@step('{n_truncate:d} tokens to truncate')
+def step_truncate_tokens(context, n_truncate):
+    context.n_truncate = n_truncate
+
+
 @step('a response format {response_format}')
 def step_response_format(context, response_format):
     context.response_format = json.loads(response_format)
@@ -399,6 +428,11 @@ def step_bos_token(context, bos):
     context.bos = bos
 
 
+@step('a list of stop strings {stop_list}')
+def step_stop_string(context, stop_list):
+    context.stop_string = json.loads(stop_list)
+
+
 @step('a prefix prompt')
 def step_prompt_prefix(context):
     context.prompt_prefix = context_text(context)
@@ -510,6 +544,11 @@ def step_many_prompts(context, num_prompts, prompt, seed):
     context.n_prompts = len(context.prompts)
 
 
+@step('an ongoing prompt')
+def step_a_ongoing_prompt(context):
+    context.prompts[-1] += context_text(context)
+
+
 @step('concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
@@ -861,7 +900,10 @@ async def request_completion(prompt,
                              prompt_prefix=None,
                              prompt_suffix=None,
                              n_predict=None,
+                             n_keep=0,
+                             n_truncate=0,
                              cache_prompt=False,
+                             stop_string=None,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
@@ -884,7 +926,10 @@ async def request_completion(prompt,
                                     "prompt": prompt,
                                     "input_suffix": prompt_suffix,
                                     "n_predict": n_predict if n_predict is not None else -1,
+                                    "n_keep": n_keep,
+                                    "n_truncate": n_truncate,
                                     "cache_prompt": cache_prompt,
+                                    "stop": stop_string if stop_string is not None else [],
                                     "id_slot": id_slot,
                                     "seed": seed if seed is not None else 42,
                                     "temperature": temperature if temperature is not None else 0.8,
diff --git a/examples/server/tests/features/truncation.feature b/examples/server/tests/features/truncation.feature
new file mode 100644
index 000000000..0a309df60
--- /dev/null
+++ b/examples/server/tests/features/truncation.feature
@@ -0,0 +1,44 @@
+# run with: ./tests.sh --no-skipped --tags truncation
+@truncation
+@slow
+Feature: Chat truncation
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models
+    And prompt caching is enabled
+    And a list of stop strings ["\n"]
+    And 82 tokens to keep
+    And 256 KV cache size
+    And 32 server max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Correctly truncate the prompt when the prompt exceeds the context size
+    Given a prompt:
+    """
+    Continue the chat below.
+    Me: Hey there, how's it going?
+    You: I'm doing well, thanks for asking! How are you?
+    Me: I'm doing good, just trying to get some work done. How's your day?
+    You: My day has been pretty productive so far. I've been working on a new project.
+    Me: That's great to hear! What's the new project you're working on?
+    You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.
+    Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
+    You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
+    Me: That's really nice, are you happy with the progress so far?
+
+    """
+    And an ongoing completion request
+    Then -1 tokens are predicted matching You:
+    Given an ongoing prompt:
+    """
+
+    Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
+
+    """
+    And 52 tokens to truncate
+    And a completion request with no api error
+    Then -1 tokens are predicted matching You:
+    # 28 because '\n' stop string is not pushed to the context
+    And 28 prompt tokens are processed
diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh
index 72a0fbad8..56cfb3429 100755
--- a/examples/server/tests/tests.sh
+++ b/examples/server/tests/tests.sh
@@ -5,7 +5,7 @@ set -eu
 if [ $# -lt 1 ]
 then
   # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
+  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey|truncation' --tags llama.cpp
 else
   behave "$@"
 fi