server: fix cancel tests

ochafik 2024-09-29 19:12:59 +01:00
parent 88c9b5497a
commit 3f96ab04a6
3 changed files with 53 additions and 32 deletions


@@ -2349,6 +2349,7 @@ struct server_context {
             completion_token_output result;
             if (params.testing_sampler_delay_millis > 0) {
+                LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
                 std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
             }
             const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
@@ -3006,7 +3007,7 @@ int main(int argc, char ** argv) {
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
         bool stream = json_value(data, "stream", false);
         handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3136,7 +3137,7 @@ int main(int argc, char ** argv) {
             return;
         }
         handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function<bool()> & is_alive) {
             std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
@@ -3164,7 +3165,7 @@ int main(int argc, char ** argv) {
             json root = is_openai
                 ? format_embeddings_response_oaicompat(body, responses)
                 : responses[0];
             res_ok(res, &sink, root);
         });
     };
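The handler hunks above all thread an is_alive predicate into create_tasks_cmpl, and the first hunk adds a test-only per-token delay, so a client that drops its connection mid-request can be noticed while generation is still running. A minimal sketch of that pattern in Python, purely illustrative (the names below are hypothetical, not the server's C++ API):

import time
from typing import Callable, List

def generate_tokens(max_tokens: int, sampler_delay_s: float,
                    is_alive: Callable[[], bool]) -> List[str]:
    # Toy generation loop: check a connection-liveness predicate between steps,
    # with a sleep standing in for the artificial sampler delay used in tests.
    tokens: List[str] = []
    for i in range(max_tokens):
        if not is_alive():           # client disconnected: stop and free the slot
            break
        time.sleep(sampler_delay_s)  # stand-in for sampling one token
        tokens.append(f"tok{i}")
    return tokens

if __name__ == "__main__":
    deadline = time.monotonic() + 0.1  # pretend the client hangs up after 100 ms
    print(generate_tokens(64, 0.3, lambda: time.monotonic() < deadline))  # -> ['tok0']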


@@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests
   Background: Server startup
     Given a server listening on localhost:8080
-    And 500 milliseconds delay in sampler for testing
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model file test-model.gguf
     And a model alias tinyllama-2
@@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests
     # KV Cache corresponds to the total amount of tokens
     # that can be stored across all independent sequences: #4130
     # see --ctx-size and #5568
-    And 256 KV cache size
+    And 512 KV cache size
     And 32 as batch size
-    And 1 slots
+    And 2 slots
     And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And 300 milliseconds delay in sampler for testing
+    And no warmup
     Then the server is starting
     Then the server is healthy
-    # Then the server is healthy with timeout 10 seconds
-  # Scenario: Health
-  #   Then the server is ready
-  #   And all slots are idle
-  @wip
-  Scenario Outline: Cancelling completion request frees up slot
-    Given a prompt:
-    """
-    Once upon
-    """
+  Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    And a user prompt Once upon a time
+    And a system prompt You tell lengthy stories
     And 256 max tokens to predict
     And 256 server max tokens to predict
     And streaming is <enable_streaming>
-    And a completion request cancelled after 100 milliseconds
-    # And wait for 50 milliseconds
+    And disconnect after 100 milliseconds
+    Given concurrent OAI completions requests
+    And wait for 700 milliseconds
+    Then all slots are idle
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
+
+  Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    Given a prompt Once upon a time
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And disconnect after 100 milliseconds
+    Given a completion request with no api error
+    And wait for 700 milliseconds
     Then all slots are idle
     Examples: Prompts
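The timing values in the updated scenarios fit together: with a 300 ms per-token sampler delay, the 100 ms disconnect is guaranteed to land while the first token is still being sampled, and the 700 ms wait leaves the server time to finish that token, attempt the next one, notice the dead connection, and release the slot before "all slots are idle" is checked. A small sketch of that budget, with the values copied from the feature file above (the margins are my reading of the setup, not something the tests assert):

# Timing budget behind the cancellation scenarios (sketch; values from the feature file above).
SAMPLER_DELAY_MS = 300       # "And 300 milliseconds delay in sampler for testing"
DISCONNECT_AFTER_MS = 100    # "And disconnect after 100 milliseconds"
WAIT_BEFORE_ASSERT_MS = 700  # "And wait for 700 milliseconds"

# The client hangs up while the first token is still being sampled...
assert DISCONNECT_AFTER_MS < SAMPLER_DELAY_MS
# ...and the idle-slots check only runs after at least two more sampling steps
# could have completed, giving the server room to notice the disconnect.
assert WAIT_BEFORE_ASSERT_MS >= 2 * SAMPLER_DELAY_MS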


@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                              n_predict=context.n_predict,
                              cache_prompt=context.cache_prompt,
                              id_slot=context.id_slot,
+                             disconnect_after_millis=context.disconnect_after_millis,
                              expect_api_error=expect_api_error,
                              user_api_key=context.user_api_key,
                              temperature=context.temperature)
@@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
 async def step_request_completion(context, millis: int):
     await asyncio.sleep(millis / 1000.0)

-@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
 @async_run_until_complete
-async def step_request_completion(context, disconnect_after_millis: int):
-    seeds = await completions_seed(context, num_seeds=1)
-    await request_completion(context.prompts.pop(),
-                             seeds[0] if seeds is not None else seeds,
-                             context.base_url,
-                             debug=context.debug,
-                             n_predict=context.n_predict,
-                             cache_prompt=context.cache_prompt,
-                             id_slot=context.id_slot,
-                             disconnect_after_millis=disconnect_after_millis,
-                             user_api_key=context.user_api_key,
-                             temperature=context.temperature)
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis

 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
     print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,
@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,
+                                            disconnect_after_millis=context.disconnect_after_millis,
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:
@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
                               if hasattr(context, 'enable_streaming') else None,
                               response_format=context.response_format
                               if hasattr(context, 'response_format') else None,
+                              disconnect_after_millis=context.disconnect_after_millis,
                               user_api_key=context.user_api_key
                               if hasattr(context, 'user_api_key') else None)
@@ -1029,9 +1027,9 @@ async def request_completion(prompt,
                                 },
                                 headers=headers) as response:
             if disconnect_after_millis is not None:
-                await asyncio.sleep(disconnect_after_millis / 1000)
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
                 return 0
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
                               temperature=None,
                               model=None,
                               n_predict=None,
+                              disconnect_after_millis=None,
                               enable_streaming=None,
                               response_format=None,
                               user_api_key=None,
@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
             async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
+                if disconnect_after_millis is not None:
+                    await asyncio.sleep(disconnect_after_millis / 1000.0)
+                    return 0
+
                 if enable_streaming:
                     assert response.status == 200
                     assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
                 else:
                     return response.status
     else:
+        assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
         try:
             openai.api_key = user_api_key
             openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
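The client-side mechanism the new helpers rely on is simply leaving the aiohttp response context early: after the optional sleep they return without reading the body, so the connection is closed when the async with blocks exit and the server sees the disconnect. A self-contained sketch of that behaviour (the endpoint path and payload here are assumptions for illustration; the real helpers build the URL from base_url and base_path):

import asyncio
import aiohttp

async def fire_and_disconnect(base_url: str, payload: dict, disconnect_after_millis: int) -> int:
    # Start a completion request, wait a bit, then abandon it without reading
    # the (possibly streaming) body; exiting the context managers closes the
    # connection, which is the disconnect the server has to detect.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/v1/chat/completions", json=payload) as response:
            await asyncio.sleep(disconnect_after_millis / 1000.0)
            return 0  # mirrors the helpers above: bail out early, report nothing

# Example, mirroring "disconnect after 100 milliseconds" against a local test server:
# asyncio.run(fire_and_disconnect("http://localhost:8080",
#                                 {"messages": [{"role": "user", "content": "Once upon a time"}],
#                                  "max_tokens": 256, "stream": True}, 100))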