server : add --no-context-shift option (#9607)
* server : add --no-context-shift option
* small fix
* Update examples/server/tests/features/embeddings.feature

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* tests : minor fix
* revert usage of GGML_ASSERT
* update server documentation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
f0c7b5edf8
commit
0b3bf966f4
6 changed files with 139 additions and 22 deletions
|
@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||
context.response_format = None
|
||||
context.temperature = None
|
||||
context.lora_file = None
|
||||
context.disable_ctx_shift = False
|
||||
|
||||
context.tasks_result = []
|
||||
context.concurrent_tasks = []
|
||||
|
@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
|
|||
|
||||
@step('{n_predict:d} server max tokens to predict')
|
||||
def step_server_n_predict(context, n_predict: int):
|
||||
context.n_server_predict = n_predict
|
||||
context.n_server_predict = n_predict if n_predict > 0 else None
|
||||
|
||||
|
||||
@step('{slot_save_path} as slot save path')
|
||||
|
@ -180,6 +181,9 @@ def step_server_embeddings(context):
|
|||
def step_server_metrics(context):
|
||||
context.server_metrics = True
|
||||
|
||||
@step('disable context shifting')
|
||||
def step_server_disable_ctx_shift(context):
|
||||
context.disable_ctx_shift = True
|
||||
|
||||
@step("the server is starting")
|
||||
def step_start_server(context):
|
||||
|
@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
|
|||
@step('a completion request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||
expect_api_error = api_error == 'raised'
|
||||
expect_api_error = api_error == 'raised' or api_error != 'no'
|
||||
seeds = await completions_seed(context, num_seeds=1)
|
||||
completion = await request_completion(context.prompts.pop(),
|
||||
seeds[0] if seeds is not None else seeds,
|
||||
|
@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
|
|||
context.tasks_result.append(completion)
|
||||
if context.debug:
|
||||
print(f"Completion response: {completion}")
|
||||
if expect_api_error:
|
||||
if api_error == 'raised':
|
||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||
elif api_error.isdigit():
|
||||
api_error_code = int(api_error)
|
||||
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
||||
|
||||
|
||||
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
||||
|
@ -645,6 +652,9 @@ def step_assert_embeddings(context):
|
|||
for embedding in context.embeddings:
|
||||
assert_embeddings(embedding)
|
||||
|
||||
@step('embeddings request with {api_error_code:d} api error')
|
||||
def step_assert_embeddings(context, api_error_code: int):
|
||||
assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
|
||||
|
||||
@step('an OAI compatible embeddings computation request for')
|
||||
@async_run_until_complete
|
||||
|
@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
|
|||
return completion_response
|
||||
|
||||
|
||||
async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
|
||||
async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
|
||||
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||
async with session.post(f'{base_url}/embedding',
|
||||
json={
|
||||
"content": content,
|
||||
}) as response:
|
||||
assert response.status == 200
|
||||
response_json = await response.json()
|
||||
return [response_json['embedding']]
|
||||
if response.status == 200:
|
||||
response_json = await response.json()
|
||||
return [response_json['embedding']]
|
||||
else:
|
||||
return response.status
|
||||
|
||||
|
||||
async def request_oai_embeddings(input, seed,
|
||||
|
@ -1372,6 +1384,8 @@ def start_server_background(context):
|
|||
server_args.append('--verbose')
|
||||
if context.lora_file:
|
||||
server_args.extend(['--lora', context.lora_file])
|
||||
if context.disable_ctx_shift:
|
||||
server_args.extend(['--no-context-shift'])
|
||||
|
||||
args = [str(arg) for arg in [context.server_path, *server_args]]
|
||||
print(f"bench: starting server with: {' '.join(args)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue