llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <martindevans@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
87fb5b4234
commit
beea6e1b16
11 changed files with 1086 additions and 31 deletions
|
@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
|
|||
context.n_predict = None
|
||||
context.n_prompts = 0
|
||||
context.n_server_predict = None
|
||||
context.slot_save_path = None
|
||||
context.id_slot = None
|
||||
context.cache_prompt = None
|
||||
context.n_slots = None
|
||||
context.prompt_prefix = None
|
||||
context.prompt_suffix = None
|
||||
|
@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
|
|||
context.n_server_predict = n_predict
|
||||
|
||||
|
||||
@step('{slot_save_path} as slot save path')
def step_slot_save_path(context, slot_save_path):
    """Remember the slot save path requested by the scenario.

    The value is stored on ``context.slot_save_path``; the server launcher
    forwards it as ``--slot-save-path`` when it is set.
    """
    context.slot_save_path = slot_save_path
|
||||
|
||||
|
||||
@step('using slot id {id_slot:d}')
def step_id_slot(context, id_slot):
    """Pin subsequent completion requests to the given slot.

    ``id_slot`` is parsed as an integer by the step pattern and stored on
    ``context.id_slot``, from where completion requests pick it up.
    """
    context.id_slot = id_slot
|
||||
|
||||
|
||||
@step('prompt caching is enabled')
def step_enable_prompt_cache(context):
    """Turn on prompt caching for subsequent completion requests.

    Sets ``context.cache_prompt`` to ``True``; completion requests send this
    flag as ``cache_prompt``.
    """
    context.cache_prompt = True
|
||||
|
||||
|
||||
@step('continuous batching')
def step_server_continuous_batching(context):
    """Flag the scenario as using continuous batching.

    Sets ``context.server_continuous_batching`` to ``True``.
    """
    context.server_continuous_batching = True
|
||||
|
@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
|
|||
context.base_url,
|
||||
debug=context.debug,
|
||||
n_predict=context.n_predict,
|
||||
cache_prompt=context.cache_prompt,
|
||||
id_slot=context.id_slot,
|
||||
seed=await completions_seed(context),
|
||||
expect_api_error=expect_api_error,
|
||||
user_api_key=context.user_api_key)
|
||||
|
@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
|
|||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@step('the slot {slot_id:d} is saved with filename "{filename}"')
@async_run_until_complete
async def step_save_slot(context, slot_id, filename):
    """Ask the server to save slot ``slot_id`` under ``filename``.

    Sends ``POST {base_url}/slots/{slot_id}?action=save`` with a JSON body
    carrying the filename, and keeps the response on ``context.response``
    so a later step can check its status code.
    """
    url = f'{context.base_url}/slots/{slot_id}?action=save'
    payload = {"filename": filename}
    request_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=request_headers) as response:
            context.response = response
|
||||
|
||||
|
||||
@step('the slot {slot_id:d} is restored with filename "{filename}"')
@async_run_until_complete
async def step_restore_slot(context, slot_id, filename):
    """Ask the server to restore slot ``slot_id`` from ``filename``.

    Sends ``POST {base_url}/slots/{slot_id}?action=restore`` with a JSON body
    carrying the filename, and keeps the response on ``context.response``
    so a later step can check its status code.
    """
    url = f'{context.base_url}/slots/{slot_id}?action=restore'
    payload = {"filename": filename}
    request_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=request_headers) as response:
            context.response = response
|
||||
|
||||
|
||||
@step('the slot {slot_id:d} is erased')
@async_run_until_complete
async def step_erase_slot(context, slot_id):
    """Ask the server to erase slot ``slot_id``.

    Sends ``POST {base_url}/slots/{slot_id}?action=erase`` without a body
    and keeps the response on ``context.response`` so a later step can
    check its status code.
    """
    url = f'{context.base_url}/slots/{slot_id}?action=erase'
    request_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=request_headers) as response:
            context.response = response
|
||||
|
||||
|
||||
@step('the server responds with status code {status_code:d}')
def step_server_responds_with_status_code(context, status_code):
    """Check the status code of the most recently stored response.

    Compares ``context.response.status`` with ``status_code`` (parsed as an
    integer by the step pattern).

    Raises:
        AssertionError: if the statuses differ; the message reports both
            values so a failing scenario is diagnosable from the log alone.
    """
    actual_status = context.response.status
    assert actual_status == status_code, (
        f"expected HTTP status {status_code}, got {actual_status}"
    )
|
||||
|
||||
|
||||
async def request_completion(prompt,
|
||||
base_url,
|
||||
debug=False,
|
||||
prompt_prefix=None,
|
||||
prompt_suffix=None,
|
||||
n_predict=None,
|
||||
cache_prompt=False,
|
||||
id_slot=None,
|
||||
seed=None,
|
||||
expect_api_error=None,
|
||||
user_api_key=None):
|
||||
|
@ -738,6 +794,8 @@ async def request_completion(prompt,
|
|||
"prompt": prompt,
|
||||
"input_suffix": prompt_suffix,
|
||||
"n_predict": n_predict if n_predict is not None else -1,
|
||||
"cache_prompt": cache_prompt,
|
||||
"id_slot": id_slot,
|
||||
"seed": seed if seed is not None else 42
|
||||
},
|
||||
headers=headers,
|
||||
|
@ -1104,6 +1162,8 @@ def start_server_background(context):
|
|||
server_args.extend(['--parallel', context.n_slots])
|
||||
if context.n_server_predict:
|
||||
server_args.extend(['--n-predict', context.n_server_predict])
|
||||
if context.slot_save_path:
|
||||
server_args.extend(['--slot-save-path', context.slot_save_path])
|
||||
if context.server_api_key:
|
||||
server_args.extend(['--api-key', context.server_api_key])
|
||||
if context.n_ga:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue