server: test request cancellation (WIP)

commit 5f00747a90
parent 4dcb3ea943

2 changed files with 69 additions and 2 deletions
examples/server/tests/features/cancel.feature (new file, 43 lines)
@@ -0,0 +1,43 @@
+@llama.cpp
+@server
+Feature: Cancellation of llama.cpp server requests
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And 500 milliseconds delay in sampler for testing
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    # KV Cache corresponds to the total amount of tokens
+    # that can be stored across all independent sequences: #4130
+    # see --ctx-size and #5568
+    And 256 KV cache size
+    And 32 as batch size
+    And 1 slots
+    And 64 server max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+  # Scenario: Health
+  #   Then the server is ready
+  #   And all slots are idle
+
+  @wip
+  Scenario Outline: Cancelling completion request frees up slot
+    Given a prompt:
+    """
+    Once upon
+    """
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And a completion request cancelled after 100 milliseconds
+    # And wait for 50 milliseconds
+    Then all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
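The @wip scenario above boils down to: start a completion, drop the connection part-way through, then check that the server frees the slot. A minimal standalone sketch of that flow, assuming a llama.cpp server is already running on localhost:8080 (the /completion and /health routes are the ones these tests target; the prompt, n_predict and the 100 ms delay are illustrative only):

import asyncio
import aiohttp


async def cancel_completion_after(millis: int) -> None:
    base_url = "http://localhost:8080"
    async with aiohttp.ClientSession() as session:
        # Start a streaming completion, then abandon it before it finishes.
        async with session.post(f"{base_url}/completion",
                                json={"prompt": "Once upon", "n_predict": 256, "stream": True}):
            await asyncio.sleep(millis / 1000)
        # Leaving the block without reading the body drops the connection; the
        # server should notice the disconnect, stop generating and free the slot.
        async with session.get(f"{base_url}/health") as health:
            assert health.status == 200


asyncio.run(cancel_completion_after(100))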
examples/server/tests/features/steps/steps.py
@@ -291,6 +291,25 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
         api_error_code = int(api_error)
         assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
+@step('wait for {millis:d} milliseconds')
+@async_run_until_complete
+async def step_request_completion(context, millis: int):
+    await asyncio.sleep(millis / 1000.0)
+
+
+@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+@async_run_until_complete
+async def step_request_completion(context, disconnect_after_millis: int):
+    seeds = await completions_seed(context, num_seeds=1)
+    await request_completion(context.prompts.pop(),
+                             seeds[0] if seeds is not None else seeds,
+                             context.base_url,
+                             debug=context.debug,
+                             n_predict=context.n_predict,
+                             cache_prompt=context.cache_prompt,
+                             id_slot=context.id_slot,
+                             disconnect_after_millis=disconnect_after_millis,
+                             user_api_key=context.user_api_key,
+                             temperature=context.temperature)
+
+
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
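The {millis:d} and {disconnect_after_millis:d} placeholders in the new steps rely on behave's default parse-based matcher; a quick sketch of the conversion, assuming the parse package that behave depends on:

import parse

# ':d' converts the captured text to int, so the step function receives 50, not "50".
result = parse.parse('wait for {millis:d} milliseconds', 'wait for 50 milliseconds')
assert result is not None
assert result['millis'] == 50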
@@ -982,9 +1001,10 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
+                             disconnect_after_millis=None,
                              temperature=None) -> int | dict[str, Any]:
     if debug:
-        print(f"Sending completion request: {prompt}")
+        print(f"Sending completion request: {prompt} with n_predict={n_predict}")
     origin = "my.super.domain"
     headers = {
         'Origin': origin
@@ -1008,6 +1028,10 @@ async def request_completion(prompt,
                                     "n_probs": 2,
                                 },
                                 headers=headers) as response:
+            if disconnect_after_millis is not None:
+                await asyncio.sleep(disconnect_after_millis / 1000)
+                return 0
+
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
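The early `return 0` above works because leaving the `async with session.post(...)` block before the streamed body is read releases the connection, which is exactly the client disconnect the server has to handle. A sketch of the same idea with an explicit close (an illustration under that assumption, not what the commit does; abort_request and its arguments are made up for the example):

import asyncio
import aiohttp


async def abort_request(url: str, payload: dict, after_millis: int) -> None:
    async with aiohttp.ClientSession() as session:
        response = await session.post(url, json=payload)
        try:
            await asyncio.sleep(after_millis / 1000)
        finally:
            # Drop the connection without reading the streamed body.
            response.close()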
@@ -1352,7 +1376,7 @@ async def request_slots_status(context, expected_slots):
 
 
 def assert_slots_status(slots, expected_slots):
-    assert len(slots) == len(expected_slots)
+    assert len(slots) == len(expected_slots), f'invalid number of slots: {len(slots)} (actual) != {len(expected_slots)} (expected)'
     for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
         for key in expected:
             assert expected[key] == slot[key], (f"invalid slot {slot_id}"
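With the extended assertion, a slot-count mismatch now reports both values; a self-contained illustration of the message format (the slot data here is made up):

# Mirrors the message format added to assert_slots_status above.
slots = [{"id": 0, "state": 0}]
expected_slots = [{"id": 0, "state": 0}, {"id": 1, "state": 0}]
assert len(slots) == len(expected_slots), \
    f'invalid number of slots: {len(slots)} (actual) != {len(expected_slots)} (expected)'
# AssertionError: invalid number of slots: 1 (actual) != 2 (expected)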