server : add --no-context-shift option (#9607)

* server : add --no-context-shift option

* small fix

* Update examples/server/tests/features/embeddings.feature

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* tests : minor fix

* revert usage of GGML_ASSERT

* update server documentation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan Son Nguyen 2024-09-23 22:23:54 +02:00 committed by GitHub
parent f0c7b5edf8
commit 0b3bf966f4
6 changed files with 139 additions and 22 deletions


@ -0,0 +1,62 @@
@llama.cpp
@ctx_shift
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
And 256 KV cache size
And 32 as batch size
And 2 slots
Scenario: Inference with context shift
And 64 server max tokens to predict
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
And the completion is truncated
And 109 prompt tokens are processed
Scenario Outline: Inference without context shift
And <n_predict> server max tokens to predict
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Hi how are you
"""
And a completion request with no api error
Then <n_token_output> tokens are predicted matching twind|Anna
And the completion is <truncated> truncated
And 8 prompt tokens are processed
Examples:
| n_predict | n_token_output | truncated |
| 64 | 64 | not |
| -1 | 120 | |
Scenario: Inference without context shift (expected error: prompt too long)
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with 400 api error
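The numbers in the `Examples` table above are consistent with a simple per-slot token budget. The sketch below is an illustrative model only, not the server's implementation. It assumes the 256-token KV cache is split evenly across the 2 slots (128 tokens each) and that, with context shifting disabled, generation stops once prompt plus output fills the slot. The function name `max_new_tokens`, its parameters, and the exact boundary check are hypothetical.

```python
def max_new_tokens(n_ctx: int, n_slots: int, n_prompt: int, n_predict: int) -> int:
    """Tokens a slot can generate with context shifting disabled (illustrative model)."""
    # Assumption: the KV cache is divided evenly between the slots.
    n_ctx_slot = n_ctx // n_slots
    if n_prompt >= n_ctx_slot:
        # Mirrors the "prompt too long" scenario, which expects a 400 response.
        raise ValueError("prompt does not fit in the slot context")
    room = n_ctx_slot - n_prompt
    # n_predict < 0 means "no limit": generation runs until the slot is full.
    return room if n_predict < 0 else min(n_predict, room)


print(max_new_tokens(256, 2, 8, 64))   # 64, the first row of the table
print(max_new_tokens(256, 2, 8, -1))   # 120, the second row of the table
```

Under this model the 8-token prompt leaves 120 tokens of room, which matches the `-1 | 120` row.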


@ -10,11 +10,11 @@ Feature: llama.cpp server
And 42 as server seed
And 2 slots
# the bert-bge-small model has context size of 512
-# since the generated prompts are as big as the batch size, we need to set the batch size to 512
+# since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
-And 512 as batch size
-And 512 as ubatch size
-And 2048 KV cache size
+And 128 as batch size
+And 128 as ubatch size
+And 512 KV cache size
And embeddings extraction
Then the server is starting
Then the server is healthy
@ -26,6 +26,20 @@ Feature: llama.cpp server
"""
Then embeddings are generated
Scenario: Embedding (error: prompt too long)
When embeddings are computed for:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And embeddings request with 500 api error
Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
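The "prompt too long" scenario works because the embeddings step compares the stored result directly with an HTTP status code, so the result is either a list of vectors or an integer. A minimal sketch of telling the two apart is shown here; `check_embedding_result` is a hypothetical helper and is not part of the test suite.

```python
from typing import List, Optional, Union

# Either a list of embedding vectors (success) or an HTTP status code (failure).
EmbeddingResult = Union[List[List[float]], int]


def check_embedding_result(result: EmbeddingResult,
                           expected_status: Optional[int] = None) -> None:
    """Assert that `result` is either the expected error code or a list of vectors."""
    if expected_status is not None:
        assert result == expected_status, (
            f"embeddings request must return code {expected_status}, but got {result}"
        )
        return
    assert isinstance(result, list) and len(result) > 0, (
        f"expected embedding vectors, but got {result}"
    )


check_embedding_result(500, expected_status=500)   # error path
check_embedding_result([[0.1, 0.2, 0.3]])          # success path
```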


@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.lora_file = None
context.disable_ctx_shift = False
context.tasks_result = []
context.concurrent_tasks = []
@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
-context.n_server_predict = n_predict
+context.n_server_predict = n_predict if n_predict > 0 else None
@step('{slot_save_path} as slot save path')
@ -180,6 +181,9 @@ def step_server_embeddings(context):
def step_server_metrics(context):
context.server_metrics = True
@step('disable context shifting')
def step_server_disable_ctx_shift(context):
context.disable_ctx_shift = True
@step("the server is starting")
def step_start_server(context):
@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
-expect_api_error = api_error == 'raised'
+expect_api_error = api_error == 'raised' or api_error != 'no'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
-if expect_api_error:
+if api_error == 'raised':
assert completion == 401, f"completion must be an 401 status code: {completion}"
elif api_error.isdigit():
api_error_code = int(api_error)
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('{predicted_n:d} tokens are predicted matching {re_content}')
@ -645,6 +652,9 @@ def step_assert_embeddings(context):
for embedding in context.embeddings:
assert_embeddings(embedding)
@step('embeddings request with {api_error_code:d} api error')
def step_assert_embeddings(context, api_error_code: int):
assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
@step('an OAI compatible embeddings computation request for')
@async_run_until_complete
@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
return completion_response
-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{base_url}/embedding',
json={
"content": content,
}) as response:
-assert response.status == 200
-response_json = await response.json()
-return [response_json['embedding']]
+if response.status == 200:
+response_json = await response.json()
+return [response_json['embedding']]
+else:
+return response.status
async def request_oai_embeddings(input, seed,
@ -1372,6 +1384,8 @@ def start_server_background(context):
server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
if context.disable_ctx_shift:
server_args.extend(['--no-context-shift'])
args = [str(arg) for arg in [context.server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
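The `{api_error}` step argument in `step_request_completion` now carries three kinds of values: `no`, `raised`, or a numeric status code such as `400`. The mapping can be sketched as a standalone function. `expected_status` is a hypothetical helper, not code from this commit. Unlike the step itself, which makes no status assertion for an unrecognized string, the sketch raises `ValueError`.

```python
from typing import Optional


def expected_status(api_error: str) -> Optional[int]:
    """Map the Gherkin `{api_error}` argument to the status code the step checks."""
    if api_error == 'no':
        return None              # no API error expected
    if api_error == 'raised':
        return 401               # the step asserts a 401 for 'raised'
    if api_error.isdigit():
        return int(api_error)    # e.g. "400" in the prompt-too-long scenario
    # Assumption: reject anything else instead of silently skipping the check.
    raise ValueError(f"unrecognized api_error value: {api_error!r}")


for value in ('no', 'raised', '400'):
    print(value, '->', expected_status(value))
```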