server: fix cancel tests

ochafik 2024-09-29 19:12:59 +01:00
parent 88c9b5497a
commit 3f96ab04a6
3 changed files with 53 additions and 32 deletions


@@ -2349,6 +2349,7 @@ struct server_context {
completion_token_output result;
if (params.testing_sampler_delay_millis > 0) {
LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
}
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);


@@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests
Background: Server startup
Given a server listening on localhost:8080
And 500 milliseconds delay in sampler for testing
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
@@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 256 KV cache size
And 512 KV cache size
And 32 as batch size
And 1 slots
And 2 slots
And 64 server max tokens to predict
And prometheus compatible metrics exposed
And 300 milliseconds delay in sampler for testing
And no warmup
Then the server is starting
Then the server is healthy
# Then the server is healthy with timeout 10 seconds
# Scenario: Health
# Then the server is ready
# And all slots are idle
@wip
Scenario Outline: Cancelling completion request frees up slot
Given a prompt:
"""
Once upon
"""
Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
Given a model llama-2
And a user prompt Once upon a time
And a system prompt You tell lengthy stories
And 256 max tokens to predict
And 256 server max tokens to predict
And streaming is <enable_streaming>
And a completion request cancelled after 100 milliseconds
# And wait for 50 milliseconds
And disconnect after 100 milliseconds
Given concurrent OAI completions requests
And wait for 700 milliseconds
Then all slots are idle
Examples: Prompts
| enable_streaming |
| disabled |
| enabled |
Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
Given a model llama-2
Given a prompt Once upon a time
And 256 max tokens to predict
And 256 server max tokens to predict
And streaming is <enable_streaming>
And disconnect after 100 milliseconds
Given a completion request with no api error
And wait for 700 milliseconds
Then all slots are idle
Examples: Prompts

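The timing values in the scenarios above are deliberately related: the sampler delay throttles token generation, the disconnect happens well inside that window, and the final wait gives the server time to notice the dropped connection. A small illustrative sketch of that arithmetic in Python (the assertions are not part of the test suite, just a restatement of the values chosen in the feature file):

# Values taken from the feature file above.
sampler_delay_ms = 300     # "300 milliseconds delay in sampler for testing"
disconnect_after_ms = 100  # "disconnect after 100 milliseconds"
settle_ms = 700            # "wait for 700 milliseconds"

# The server sleeps for sampler_delay_ms before sampling each token, so the
# client always disconnects while the slot is still busy generating.
assert disconnect_after_ms < sampler_delay_ms

# After the disconnect, the extra wait leaves margin for the server to detect
# the dropped connection and release the slot before the scenario asserts
# that all slots are idle.
assert settle_ms > sampler_delay_ms + disconnect_after_ms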

@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.lora_file = None
context.testing_sampler_delay_millis = None
context.disable_ctx_shift = False
context.disconnect_after_millis = None
context.tasks_result = []
context.concurrent_tasks = []
@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
n_predict=context.n_predict,
cache_prompt=context.cache_prompt,
id_slot=context.id_slot,
disconnect_after_millis=context.disconnect_after_millis,
expect_api_error=expect_api_error,
user_api_key=context.user_api_key,
temperature=context.temperature)
@@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
async def step_request_completion(context, millis: int):
await asyncio.sleep(millis / 1000.0)
@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
@step('disconnect after {disconnect_after_millis:d} milliseconds')
@async_run_until_complete
async def step_request_completion(context, disconnect_after_millis: int):
seeds = await completions_seed(context, num_seeds=1)
await request_completion(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
cache_prompt=context.cache_prompt,
id_slot=context.id_slot,
disconnect_after_millis=disconnect_after_millis,
user_api_key=context.user_api_key,
temperature=context.temperature)
async def step_disconnect_after(context, disconnect_after_millis: int):
context.disconnect_after_millis = disconnect_after_millis
@step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
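The replaced step above no longer issues a request of its own; it only records the delay on the behave context, and the existing completion steps forward it to the request helpers. A minimal behave-style sketch of that pattern (the second step name is hypothetical, for illustration only):

from behave import step

@step('disconnect after {millis:d} milliseconds')
def step_set_disconnect(context, millis):
    # Configuration step: remember the delay; no request is issued here.
    context.disconnect_after_millis = millis

@step('a completion request is prepared')  # hypothetical illustration step
def step_prepare_request(context):
    # Action steps read the value back; scenarios that never set it fall back
    # to None, which is why step_server_config initializes the attribute.
    context.request_kwargs = {
        'disconnect_after_millis': getattr(context, 'disconnect_after_millis', None),
    }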
@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
print(f"Submitting OAI compatible completions request...")
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1),
seeds = await completions_seed(context, num_seeds=1)
completion = await oai_chat_completions(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
context.system_prompt,
@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,
disconnect_after_millis=context.disconnect_after_millis,
expect_api_error=expect_api_error)
context.tasks_result.append(completion)
if context.debug:
@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
disconnect_after_millis=context.disconnect_after_millis,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -1029,7 +1027,7 @@ async def request_completion(prompt,
},
headers=headers) as response:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000)
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
if expect_api_error is None or not expect_api_error:
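The early return inside the async with block is what actually simulates the cancel: leaving the response context before the body has been read makes aiohttp close the underlying connection, which the server observes as a client disconnect mid-generation. A self-contained sketch of that mechanism (URL and payload are placeholders, not taken from the test suite):

import asyncio
import aiohttp

async def post_and_disconnect(url, payload, disconnect_after_millis):
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            # Wait long enough for the server to start generating, then leave
            # the block without reading the body; the connection is closed.
            await asyncio.sleep(disconnect_after_millis / 1000.0)
            return response.status

# Example usage against a local server (placeholder endpoint and payload):
# asyncio.run(post_and_disconnect('http://localhost:8080/completion',
#                                 {'prompt': 'Once upon a time', 'n_predict': 256},
#                                 100))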
@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
temperature=None,
model=None,
n_predict=None,
disconnect_after_millis=None,
enable_streaming=None,
response_format=None,
user_api_key=None,
@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
async with session.post(f'{base_url}{base_path}',
json=payload,
headers=headers) as response:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
if enable_streaming:
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
else:
return response.status
else:
assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
try:
openai.api_key = user_api_key
openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'