add server test case for slot save restore

Jan Boon 2024-03-30 06:03:41 +08:00
parent f2e41b3239
commit 92c468105b
2 changed files with 108 additions and 0 deletions

examples/server/tests/features/slotsave.feature (new file)

@@ -0,0 +1,48 @@
@llama.cpp
@server
Feature: llama.cpp server slot management

  Background: Server startup
    Given a server listening on localhost:8080
    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
    And prompt caching is enabled
    And 2 slots
    And . as slot save path
    And 2048 KV cache size
    And 42 as server seed
    And 24 max tokens to predict
    Then the server is starting
    Then the server is healthy

  Scenario: Save and Restore Slot
    # The first prompt in slot 1 is processed in full
    Given a user prompt "What is the capital of France?"
    And using slot id 1
    And a completion request with no api error
    Then 24 tokens are predicted matching Lily
    And 22 prompt tokens are processed
    When the slot 1 is saved with filename "slot1.bin"
    Then the server responds with status code 200
    # With prompt caching, only the tokens differing from the cached France prompt are processed
    Given a user prompt "What is the capital of Germany?"
    And a completion request with no api error
    Then 24 tokens are predicted matching Thank
    And 7 prompt tokens are processed
    When the slot 2 is restored with filename "slot1.bin"
    Then the server responds with status code 200
    # The restored cache covers the whole prompt, so only a single token needs processing
    Given a user prompt "What is the capital of France?"
    And using slot id 2
    And a completion request with no api error
    Then 24 tokens are predicted matching Lily
    And 1 prompt tokens are processed

  Scenario: Erase Slot
    Given a user prompt "What is the capital of France?"
    And using slot id 1
    And a completion request with no api error
    Then 24 tokens are predicted matching Lily
    And 22 prompt tokens are processed
    When the slot 1 is erased
    Then the server responds with status code 200
    # With the cache erased, the same prompt is processed in full again
    Given a user prompt "What is the capital of France?"
    And a completion request with no api error
    Then 24 tokens are predicted matching Lily
    And 22 prompt tokens are processed
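
For orientation, the HTTP traffic this scenario drives (through the step definitions added to steps.py below) can also be reproduced with a plain aiohttp client. The following is a minimal sketch, not part of the commit: it assumes a llama.cpp server already running on localhost:8080, started with --slot-save-path and at least two slots. The endpoint paths and JSON fields mirror the steps file; everything else is illustrative.

import asyncio

import aiohttp

BASE_URL = 'http://localhost:8080'


async def main():
    async with aiohttp.ClientSession() as session:
        # Fill slot 1's KV cache with a prompt (same fields the steps file sends).
        async with session.post(f'{BASE_URL}/completion',
                                json={'prompt': 'What is the capital of France?',
                                      'id_slot': 1,
                                      'cache_prompt': True,
                                      'n_predict': 24}) as response:
            print((await response.json())['content'])

        # Persist slot 1's cache to a file under --slot-save-path.
        async with session.post(f'{BASE_URL}/slots/1?action=save',
                                json={'filename': 'slot1.bin'}) as response:
            assert response.status == 200

        # Load the saved cache into slot 2; repeating the original prompt on
        # slot 2 should then process almost no prompt tokens.
        async with session.post(f'{BASE_URL}/slots/2?action=restore',
                                json={'filename': 'slot1.bin'}) as response:
            assert response.status == 200


asyncio.run(main())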

examples/server/tests/features/steps/steps.py

@@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
    context.n_predict = None
    context.n_prompts = 0
    context.n_server_predict = None
    context.slot_save_path = None
    context.id_slot = None
    context.cache_prompt = None
    context.n_slots = None
    context.prompt_prefix = None
    context.prompt_suffix = None
@@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
    context.n_server_predict = n_predict


@step('{slot_save_path} as slot save path')
def step_slot_save_path(context, slot_save_path):
    context.slot_save_path = slot_save_path


@step('using slot id {id_slot:d}')
def step_id_slot(context, id_slot):
    context.id_slot = id_slot


@step('prompt caching is enabled')
def step_enable_prompt_cache(context):
    context.cache_prompt = True


@step('continuous batching')
def step_server_continuous_batching(context):
    context.server_continuous_batching = True
@@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
                                              context.base_url,
                                              debug=context.debug,
                                              n_predict=context.n_predict,
                                              cache_prompt=context.cache_prompt,
                                              id_slot=context.id_slot,
                                              seed=await completions_seed(context),
                                              expect_api_error=expect_api_error,
                                              user_api_key=context.user_api_key)
@@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
        await asyncio.sleep(0.1)


@step('the slot {slot_id:d} is saved with filename "{filename}"')
@async_run_until_complete
async def step_save_slot(context, slot_id, filename):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
                                json={"filename": filename},
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('the slot {slot_id:d} is restored with filename "{filename}"')
@async_run_until_complete
async def step_restore_slot(context, slot_id, filename):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
                                json={"filename": filename},
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('the slot {slot_id:d} is erased')
@async_run_until_complete
async def step_erase_slot(context, slot_id):
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
                                headers={"Content-Type": "application/json"}) as response:
            context.response = response


@step('the server responds with status code {status_code:d}')
def step_server_responds_with_status_code(context, status_code):
    assert context.response.status == status_code
async def request_completion(prompt,
                             base_url,
                             debug=False,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             cache_prompt=False,
                             id_slot=None,
                             seed=None,
                             expect_api_error=None,
                             user_api_key=None):
@@ -738,6 +794,8 @@ async def request_completion(prompt,
                                    "prompt": prompt,
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "cache_prompt": cache_prompt,
                                    "id_slot": id_slot,
                                    "seed": seed if seed is not None else 42
                                },
                                headers=headers,
@@ -1104,6 +1162,8 @@ def start_server_background(context):
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.slot_save_path:
        server_args.extend(['--slot-save-path', context.slot_save_path])
    if context.server_api_key:
        server_args.extend(['--api-key', context.server_api_key])
    if context.n_ga:
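
As a usage note: these scenarios belong to the behave-based server test suite under examples/server/tests. Assuming its Python requirements are installed and the server binary is built, an invocation along the lines of `behave --summary --stop --tags llama.cpp` should pick up the new feature file via its @llama.cpp tag; the exact flags may differ depending on the suite's README at this revision.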