server : fix cancel tests
This commit is contained in:
parent 88c9b5497a
commit 3f96ab04a6
3 changed files with 53 additions and 32 deletions
@@ -2349,6 +2349,7 @@ struct server_context {
             completion_token_output result;
             if (params.testing_sampler_delay_millis > 0) {
+                LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
                 std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
             }

             const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
@@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests

   Background: Server startup
     Given a server listening on localhost:8080
-    And 500 milliseconds delay in sampler for testing
     And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
     And a model file test-model.gguf
     And a model alias tinyllama-2
@@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests
     # KV Cache corresponds to the total amount of tokens
     # that can be stored across all independent sequences: #4130
     # see --ctx-size and #5568
-    And 256 KV cache size
+    And 512 KV cache size
     And 32 as batch size
-    And 1 slots
+    And 2 slots
     And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And 300 milliseconds delay in sampler for testing
+    And no warmup
     Then the server is starting
     Then the server is healthy
+    # Then the server is healthy with timeout 10 seconds

-  # Scenario: Health
-  # Then the server is ready
-  # And all slots are idle

-  @wip
-  Scenario Outline: Cancelling completion request frees up slot
-    Given a prompt:
-    """
-    Once upon
-    """
+  Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    And a user prompt Once upon a time
+    And a system prompt You tell lengthy stories
     And 256 max tokens to predict
     And 256 server max tokens to predict
     And streaming is <enable_streaming>
-    And a completion request cancelled after 100 milliseconds
-    # And wait for 50 milliseconds
+    And disconnect after 100 milliseconds
+    Given concurrent OAI completions requests
+    And wait for 700 milliseconds
+    Then all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled |
+      | enabled |
+
+
+  Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    Given a prompt Once upon a time
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And disconnect after 100 milliseconds
+    Given a completion request with no api error
+    And wait for 700 milliseconds
     Then all slots are idle

     Examples: Prompts
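Note: the scenarios above lean on the test-only sampler delay to make cancellation observable. With roughly 300 ms spent before each sample (per the server hunk above) and a disconnect after 100 ms, the client drops the connection while the first token is still being produced, and the 700 ms wait gives the server time to notice the broken connection and release the slot before the "Then all slots are idle" check runs. A rough asyncio sketch of that timing, using the scenario's numbers but otherwise hypothetical names, not code from the test suite:

import asyncio

SAMPLER_DELAY_S = 0.300      # "And 300 milliseconds delay in sampler for testing"
DISCONNECT_AFTER_S = 0.100   # "And disconnect after 100 milliseconds"
SETTLE_S = 0.700             # "And wait for 700 milliseconds"

async def fake_generation():
    # stand-in for the server's sampling loop with the test-only delay
    while True:
        await asyncio.sleep(SAMPLER_DELAY_S)

async def main():
    generation = asyncio.create_task(fake_generation())
    await asyncio.sleep(DISCONNECT_AFTER_S)   # client goes away mid-generation
    generation.cancel()                       # stand-in for the server dropping the request
    await asyncio.sleep(SETTLE_S)             # grace period before checking the slots
    assert generation.cancelled()             # i.e. "Then all slots are idle"

asyncio.run(main())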
@@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.lora_file = None
     context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None

     context.tasks_result = []
     context.concurrent_tasks = []
@@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                                          n_predict=context.n_predict,
                                          cache_prompt=context.cache_prompt,
                                          id_slot=context.id_slot,
+                                         disconnect_after_millis=context.disconnect_after_millis,
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key,
                                          temperature=context.temperature)
@@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
 async def step_request_completion(context, millis: int):
     await asyncio.sleep(millis / 1000.0)

-@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
 @async_run_until_complete
-async def step_request_completion(context, disconnect_after_millis: int):
-    seeds = await completions_seed(context, num_seeds=1)
-    await request_completion(context.prompts.pop(),
-                             seeds[0] if seeds is not None else seeds,
-                             context.base_url,
-                             debug=context.debug,
-                             n_predict=context.n_predict,
-                             cache_prompt=context.cache_prompt,
-                             id_slot=context.id_slot,
-                             disconnect_after_millis=disconnect_after_millis,
-                             user_api_key=context.user_api_key,
-                             temperature=context.temperature)
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis

 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
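Note: the reworked step above no longer issues a request itself; it only records the delay on the behave context, and the request helpers further down act on it. The {disconnect_after_millis:d} placeholder also makes behave hand the step an integer rather than a string. A tiny illustration of that parsing, assuming behave's default step matcher (the parse library):

from parse import parse  # behave's default "parse" matcher is built on this library

result = parse('disconnect after {disconnect_after_millis:d} milliseconds',
               'disconnect after 100 milliseconds')
assert result['disconnect_after_millis'] == 100  # ":d" converts the capture to int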
@@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error):
         print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,
@@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,

+                                            disconnect_after_millis=context.disconnect_after_millis,
+
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:
@@ -606,6 +603,7 @@ async def step_oai_chat_completions(context):
                               if hasattr(context, 'enable_streaming') else None,
                               response_format=context.response_format
                               if hasattr(context, 'response_format') else None,
+                              disconnect_after_millis=context.disconnect_after_millis,
                               user_api_key=context.user_api_key
                               if hasattr(context, 'user_api_key') else None)

@@ -1029,7 +1027,7 @@ async def request_completion(prompt,
                                     },
                                     headers=headers) as response:
             if disconnect_after_millis is not None:
-                await asyncio.sleep(disconnect_after_millis / 1000)
+                await asyncio.sleep(disconnect_after_millis / 1000.0)
                 return 0

             if expect_api_error is None or not expect_api_error:
@@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt,
                                temperature=None,
                                model=None,
                                n_predict=None,
+                               disconnect_after_millis=None,
                                enable_streaming=None,
                                response_format=None,
                                user_api_key=None,
@@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt,
             async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
+                if disconnect_after_millis is not None:
+                    await asyncio.sleep(disconnect_after_millis / 1000.0)
+                    return 0
+
                 if enable_streaming:
                     assert response.status == 200
                     assert response.headers['Access-Control-Allow-Origin'] == origin
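Note: the block added above mirrors what request_completion already does: sleep for the requested delay inside the response block, then return without reading the body, so aiohttp tears the connection down when the block exits and the server sees the client disconnect mid-request. A minimal standalone sketch of that pattern, with a placeholder URL and payload rather than the suite's real helpers:

import asyncio
import aiohttp

async def post_then_disconnect(url: str, payload: dict, disconnect_after_millis: int) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            # entering the block sends the request and the server starts generating
            await asyncio.sleep(disconnect_after_millis / 1000.0)
            # bailing out before the (streamed) body is read closes the connection
            # instead of returning it to the pool, which the server observes as a cancel
            return 0

The assertion added in the last hunk below then guards the synchronous openai-client path, which has no comparable way to abandon a request mid-flight.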
@@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt,
                 else:
                     return response.status
     else:
+        assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client"
         try:
             openai.api_key = user_api_key
             openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'