From 3f96ab04a6da3ee3b766d3a7d957fee2696910bd Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 29 Sep 2024 19:12:59 +0100
Subject: [PATCH] `server`: fix cancel tests

---
 examples/server/server.cpp                    |  7 +--
 examples/server/tests/features/cancel.feature | 44 +++++++++++++------
 examples/server/tests/features/steps/steps.py | 34 +++++++-------
 3 files changed, 53 insertions(+), 32 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 66f6c4980..0869e4623 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2349,6 +2349,7 @@ struct server_context {
             completion_token_output result;
 
             if (params.testing_sampler_delay_millis > 0) {
+                LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
                 std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
             }
             const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
@@ -3006,7 +3007,7 @@ int main(int argc, char ** argv) {
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
         bool stream = json_value(data, "stream", false);
-        
+
         handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3136,7 +3137,7 @@ int main(int argc, char ** argv) {
             return;
         }
-        
+
         handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3164,7 +3165,7 @@ int main(int argc, char ** argv) {
             json root = is_openai
                 ? format_embeddings_response_oaicompat(body, responses)
                 : responses[0];
-            
+
             res_ok(res, &sink, root);
         });
     };
diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature
index 54ded24c6..241507024 100644
--- a/examples/server/tests/features/cancel.feature
+++ b/examples/server/tests/features/cancel.feature
@@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests
 
   Background: Server startup
     Given a server listening on localhost:8080
-    And 500 milliseconds delay in sampler for testing
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model file test-model.gguf
     And a model alias tinyllama-2
@@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests
     # KV Cache corresponds to the total amount of tokens
     # that can be stored across all independent sequences: #4130
     # see --ctx-size and #5568
-    And 256 KV cache size
+    And 512 KV cache size
     And 32 as batch size
-    And 1 slots
+    And 2 slots
     And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And 300 milliseconds delay in sampler for testing
+    And no warmup
     Then the server is starting
     Then the server is healthy
+    # Then the server is healthy with timeout 10 seconds
 
-  # Scenario: Health
-  #   Then the server is ready
-  #   And all slots are idle
-
-  @wip
-  Scenario Outline: Cancelling completion request frees up slot
-    Given a prompt:
-    """
-    Once upon
-    """
+  Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    And a user prompt Once upon a time
+    And a system prompt You tell lengthy stories
     And 256 max tokens to predict
     And 256 server max tokens to predict
     And streaming is <enable_streaming>
-    And a completion request cancelled after 100 milliseconds
-    # And wait for 50 milliseconds
+    And disconnect after 100 milliseconds
+    Given concurrent OAI completions requests
+    And wait for 700 milliseconds
+    Then all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
+
+
+  Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    Given a prompt Once upon a time
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And disconnect after 100 milliseconds
+    Given a completion request with no api error
+    And wait for 700 milliseconds
     Then all slots are idle
 
     Examples: Prompts
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 5bc4b0631..561fc03ff 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                                           n_predict=context.n_predict,
                                           cache_prompt=context.cache_prompt,
                                           id_slot=context.id_slot,
+                                          disconnect_after_millis=context.disconnect_after_millis,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key,
                                           temperature=context.temperature)
@@ -296,20 +298,12 @@ async def step_request_completion(context, millis: int):
     await asyncio.sleep(millis / 1000.0)
 
-@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
 @async_run_until_complete
-async def step_request_completion(context, disconnect_after_millis: int):
-    seeds = await completions_seed(context, num_seeds=1)
-    await request_completion(context.prompts.pop(),
-                             seeds[0] if seeds is not None else seeds,
-                             context.base_url,
-                             debug=context.debug,
-                             n_predict=context.n_predict,
-                             cache_prompt=context.cache_prompt,
-                             id_slot=context.id_slot,
-                             disconnect_after_millis=disconnect_after_millis,
-                             user_api_key=context.user_api_key,
-                             temperature=context.temperature)
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis
+
 
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
         print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,
@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,
+                                            disconnect_after_millis=context.disconnect_after_millis,
+
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:
@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
                                               if hasattr(context, 'enable_streaming') else None,
                                               response_format=context.response_format
                                               if hasattr(context, 'response_format') else None,
+                                              disconnect_after_millis=context.disconnect_after_millis,
                                               user_api_key=context.user_api_key
                                               if hasattr(context, 'user_api_key') else None)
@@ -1029,9 +1027,9 @@ async def request_completion(prompt,
                                 },
                                 headers=headers) as response:
             if disconnect_after_millis is not None:
-                await asyncio.sleep(disconnect_after_millis / 1000)
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
                 return 0
-            
+
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
                                temperature=None,
                                model=None,
                                n_predict=None,
+                               disconnect_after_millis=None,
                                enable_streaming=None,
                                response_format=None,
                                user_api_key=None,
@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
             async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
+                if disconnect_after_millis is not None:
+                    await asyncio.sleep(disconnect_after_millis / 1000.0)
+                    return 0
+
                 if enable_streaming:
                     assert response.status == 200
                     assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
                 else:
                     return response.status
     else:
+        assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
        try:
            openai.api_key = user_api_key
            openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
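
Note for reviewers: the "cancellation" these tests exercise is purely client-side. The test opens the HTTP request, sleeps for disconnect_after_millis, then leaves the async with block without reading the body, so aiohttp drops the connection and the server's is_alive callback is expected to notice once the test-only sampler delay lets it catch up. The following is a minimal standalone sketch of that pattern, not part of the patch; the server URL, the /completion payload and the 100 ms delay mirror the feature files, everything else is illustrative.

# Sketch only: reproduce the disconnect-and-free-slot pattern used by the new steps,
# assuming a llama.cpp server is already running on localhost:8080 (as in the Background).
import asyncio

import aiohttp

SERVER_URL = "http://localhost:8080"  # assumption: matches "a server listening on localhost:8080"


async def fire_and_disconnect(disconnect_after_millis: int = 100) -> None:
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SERVER_URL}/completion",
                                json={"prompt": "Once upon a time", "n_predict": 256}):
            # Give the server time to start working on the request, then bail out
            # without reading the response; dropping the connection is the "cancel".
            await asyncio.sleep(disconnect_after_millis / 1000.0)
    # The feature files then wait ~700 ms and assert "all slots are idle",
    # i.e. the server noticed the dead connection and released the slot.


if __name__ == "__main__":
    asyncio.run(fire_and_disconnect())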