diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index f71e0d706..21b08ecd6 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -47,6 +47,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga = None
     context.n_ga_w = None
     context.n_predict = None
+    context.n_keep = 0
+    context.n_truncate = 0
     context.n_prompts = 0
     context.n_server_predict = None
     context.slot_save_path = None
@@ -66,6 +68,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.user_api_key = None
     context.response_format = None
 
+    context.stop_string = []
     context.tasks_result = []
     context.concurrent_tasks = []
     context.prompts = []
@@ -168,6 +171,7 @@ def step_start_server(context):
 
     addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
     family, typ, proto, _, sockaddr = addrs[0]
+    print(sockaddr)
 
     while True:
         with closing(socket.socket(family, typ, proto)) as sock:
@@ -231,17 +235,33 @@ async def step_all_slots_status(context, expected_slot_status_string):
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
 async def step_request_completion(context, api_error):
+    await make_completion_request(context, api_error)
+
+
+@step('an ongoing completion request')
+@async_run_until_complete
+async def step_request_ongoing_completion(context):
+    await make_completion_request(context, '', True)
+
+
+async def make_completion_request(context, api_error, ongoing=False):
     expect_api_error = api_error == 'raised'
-    completion = await request_completion(context.prompts.pop(),
+    prompt = context.prompts[-1] if ongoing else context.prompts.pop()
+    completion = await request_completion(prompt,
                                           context.base_url,
                                           debug=context.debug,
                                           n_predict=context.n_predict,
+                                          n_keep=context.n_keep,
+                                          n_truncate=context.n_truncate,
                                           cache_prompt=context.cache_prompt,
+                                          stop_string=context.stop_string,
                                           id_slot=context.id_slot,
                                           seed=await completions_seed(context),
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
+    if ongoing and not expect_api_error:
+        context.prompts[-1] += completion['content']
     if context.debug:
         print(f"Completion response: {completion}")
     if expect_api_error:
@@ -306,6 +326,16 @@ def step_max_tokens(context, max_tokens):
     context.n_predict = max_tokens
 
 
+@step('{n_keep:d} tokens to keep')
+def step_keep_tokens(context, n_keep):
+    context.n_keep = n_keep
+
+
+@step('{n_truncate:d} tokens to truncate')
+def step_truncate_tokens(context, n_truncate):
+    context.n_truncate = n_truncate
+
+
 @step('a response format {response_format}')
 def step_response_format(context, response_format):
     context.response_format = json.loads(response_format)
@@ -355,6 +385,10 @@ def step_n_ubatch(context, n_ubatch):
 def step_seed(context, seed):
     context.seed = seed
 
+@step('a list of stop strings {stop_list}')
+def step_stop_string(context, stop_list):
+    context.stop_string = json.loads(stop_list)
+
 
 @step('a prefix prompt')
 def step_prompt_prefix(context):
@@ -457,6 +491,11 @@ def step_a_prompt_prompt(context, prompt):
     context.n_prompts = len(context.prompts)
 
 
+@step('an ongoing prompt')
+def step_an_ongoing_prompt(context):
+    context.prompts[-1] += context_text(context)
+
+
 @step('concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
@@ -786,7 +825,10 @@ async def request_completion(prompt,
                              prompt_prefix=None,
                              prompt_suffix=None,
                              n_predict=None,
+                             n_keep=0,
+                             n_truncate=0,
                              cache_prompt=False,
+                             stop_string=None,
                              id_slot=None,
                              seed=None,
                              expect_api_error=None,
@@ -809,7 +851,10 @@ async def request_completion(prompt,
                 "prompt": prompt,
                 "input_suffix": prompt_suffix,
                 "n_predict": n_predict if n_predict is not None else -1,
+                "n_keep": n_keep,
+                "n_truncate": n_truncate,
                 "cache_prompt": cache_prompt,
+                "stop": stop_string if stop_string is not None else [],
                 "id_slot": id_slot,
                 "seed": seed if seed is not None else 42
             },
diff --git a/examples/server/tests/features/truncation.feature b/examples/server/tests/features/truncation.feature
new file mode 100644
index 000000000..74ca992f7
--- /dev/null
+++ b/examples/server/tests/features/truncation.feature
@@ -0,0 +1,44 @@
+# run with: ./tests.sh --no-skipped --tags truncation
+@truncation
+@slow
+Feature: Chat truncation
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models
+    And prompt caching is enabled
+    And a list of stop strings ["\n"]
+    And 82 tokens to keep
+    And 256 KV cache size
+    And 32 server max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Correctly truncate the prompt when the prompt exceeds the context size
+    Given a prompt:
+      """
+      Continue the chat below.
+      Me: Hey there, how's it going?
+      You: I'm doing well, thanks for asking! How are you?
+      Me: I'm doing good, just trying to get some work done. How's your day?
+      You: My day has been pretty productive so far. I've been working on a new project.
+      Me: That's great to hear! What's the new project you're working on?
+      You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.
+      Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
+      You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
+      Me: That's really nice, are you happy with the progress so far?
+
+      """
+    And an ongoing completion request
+    Then -1 tokens are predicted matching You:
+    Given an ongoing prompt:
+      """
+
+      Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
+
+      """
+    And 52 tokens to truncate
+    And a completion request with no api error
+    Then -1 tokens are predicted matching You:
+    # 28 because the '\n' stop string is not pushed to the context
+    And 28 prompt tokens are processed
\ No newline at end of file
diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh
index 72a0fbad8..56cfb3429 100755
--- a/examples/server/tests/tests.sh
+++ b/examples/server/tests/tests.sh
@@ -5,7 +5,7 @@ set -eu
 if [ $# -lt 1 ]
 then
   # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
+  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey|truncation' --tags llama.cpp
 else
   behave "$@"
 fi
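For reference, the request fields the harness now forwards (`n_keep`, `n_truncate`, `stop`) can also be exercised directly with any HTTP client. Below is a minimal sketch, assuming a llama.cpp server already listening on localhost:8080 (as in the feature's Background) and the `requests` package; the prompt text and numeric values are illustrative only, mirroring the test payload above.

```python
# Minimal sketch of a raw /completion request using the fields the patch
# wires through the test harness. Assumes a llama.cpp server is already
# running on localhost:8080; all values here are illustrative.
import requests

payload = {
    "prompt": "Continue the chat below.\nMe: Hey there, how's it going?\n",
    "n_predict": 32,       # cap on generated tokens, like the server max above
    "n_keep": 82,          # prompt tokens preserved when the context overflows
    "n_truncate": 52,      # tokens dropped on overflow (the new parameter)
    "cache_prompt": True,  # reuse the cached KV state across requests
    "stop": ["\n"],        # stop strings, as set by step_stop_string
    "seed": 42,            # matches the harness default seed
}

response = requests.post("http://localhost:8080/completion", json=payload)
response.raise_for_status()
print(response.json()["content"])
```

Repeating such a request with the previous completion appended to the prompt, as `make_completion_request` does for ongoing requests, is what eventually overflows the 256-token KV cache and triggers the truncation path the scenario asserts on.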