server : refactor slot input data, move tokenizer to HTTP thread (#10023)
* server : refactor slot input data, move tokenizer to HTTP thread * move prompt_tokens.empty() check * fix incorrect if branch * fix infinite generation loop * bring back infill validation * add infill test * try fixing format_infill * fix test * remove redundant code * rename completion to inference * update docs * use llama_tokens everywhere
This commit is contained in:
parent
40f2555797
commit
958367bf53
5 changed files with 468 additions and 348 deletions
|
@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||
context.lora_file = None
|
||||
context.disable_ctx_shift = False
|
||||
|
||||
# infill
|
||||
context.infill_input_extra = None
|
||||
context.infill_input_suffix = ''
|
||||
context.infill_input_prefix = ''
|
||||
|
||||
context.tasks_result = []
|
||||
context.concurrent_tasks = []
|
||||
context.prompts = []
|
||||
|
@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
|
|||
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
||||
|
||||
|
||||
@step('an infill request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||
if api_error != 'no':
|
||||
raise ValueError(f'api_error={api_error} is not yet implemented')
|
||||
payload = {
|
||||
"prompt": context.prompts[0],
|
||||
"input_suffix": context.infill_input_suffix,
|
||||
"input_prefix": context.infill_input_prefix,
|
||||
"n_predict": context.n_predict,
|
||||
"seed": context.seed,
|
||||
"temperature": context.temperature,
|
||||
}
|
||||
if context.infill_input_extra is not None:
|
||||
payload['input_extra'] = context.infill_input_extra
|
||||
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||
async with session.post(f'{context.base_url}/infill',
|
||||
json=payload) as response:
|
||||
assert response.status == 200
|
||||
context.tasks_result = [await response.json()]
|
||||
|
||||
|
||||
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
||||
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||
context.completion = context.tasks_result.pop()
|
||||
|
@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
|
|||
context.n_prompts = len(context.prompts)
|
||||
|
||||
|
||||
# TODO: allow this to be repeated
|
||||
@step('an infill input extra {filename} {text}')
|
||||
def step_infill_input_extra(context, filename, text):
|
||||
if filename == 'none':
|
||||
context.infill_input_extra = None
|
||||
else:
|
||||
context.infill_input_extra = [{'filename': filename, 'text': text}]
|
||||
|
||||
|
||||
@step('an infill input suffix {text}')
|
||||
def step_infill_input_suffix(context, text):
|
||||
context.infill_input_suffix = text
|
||||
|
||||
|
||||
@step('an infill input prefix {text}')
|
||||
def step_infill_input_prefix(context, text):
|
||||
context.infill_input_prefix = text
|
||||
|
||||
|
||||
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
|
||||
def step_many_prompts(context, num_prompts, prompt, seed):
|
||||
if context.seed is None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue