server : add --no-context-shift option

parent 37f8c7b4c9
commit c2e7945bb4

5 changed files with 127 additions and 12 deletions
common/arg.cpp

@@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
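Note: adding LLAMA_EXAMPLE_SERVER to the option's examples list is what exposes the existing --no-context-shift flag to the server binary. A minimal sketch of launching it from Python, mirroring what start_server_background does further down (binary and model paths are assumptions, adjust to your build):

    import subprocess

    # start llama-server with context shifting disabled (sketch; paths are assumptions)
    server = subprocess.Popen([
        "./llama-server",
        "--model", "models/stories260K.gguf",  # any small test model
        "--ctx-size", "256",                   # matches the KV cache size used in the tests
        "--no-context-shift",                  # the flag this commit exposes to the server
    ])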
examples/server/server.cpp

@@ -1180,6 +1180,15 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
+        // if context shift is disabled, we stop when it reaches the context limit
+        if (slot.n_decoded >= slot.n_ctx) {
+            slot.truncated      = true;
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+        }
+
         if (llama_token_is_eog(model, result.tok)) {
             slot.stopped_eos    = true;
             slot.has_next_token = false;
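Note: with context shift disabled, generation now stops and is flagged as truncated once n_decoded reaches n_ctx, instead of shifting the KV cache. A hedged client-side sketch of observing this (field names per the server's /completion response; assumes the server started in the sketch above):

    import requests

    resp = requests.post("http://localhost:8080/completion", json={
        "prompt": "Hi how are you",
        "n_predict": -1,  # no explicit limit: generation runs until the context edge
    })
    body = resp.json()
    # once the slot runs out of context capacity, the response is marked truncated
    print(body["truncated"], len(body["content"]))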
@@ -1480,7 +1489,7 @@ struct server_context {
             if (result.error) {
                 error_handler(result.data);
                 cancel_tasks(id_tasks);
-                break;
+                return;
             }
 
             size_t idx = result.data["index"];
@@ -1827,6 +1836,14 @@ struct server_context {
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                    if (!params.ctx_shift) {
+                        // this check is redundant (for good)
+                        // we should never get here, since n_predict is already limited
+                        slot.release();
+                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                        continue;
+                    }
+
                     // Shift context
                     const int n_keep    = slot.params.n_keep + add_bos_token;
                     const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1961,6 +1978,14 @@ struct server_context {
                     continue;
                 }
             } else {
+                if (!params.ctx_shift) {
+                    // if context shift is disabled, we make sure prompt size is smaller than KV size
+                    if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                        slot.release();
+                        send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                        continue;
+                    }
+                }
                 if (slot.params.n_keep < 0) {
                     slot.params.n_keep = slot.n_prompt_tokens;
                 }
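Note: this is the user-visible half of the change: a prompt that cannot fit in the KV cache is now rejected with ERROR_TYPE_INVALID_REQUEST (HTTP 400, as the feature test below asserts) instead of being shifted. A sketch of triggering it (assumes the 256-token-context server from the earlier sketch):

    import requests

    # a prompt far larger than n_ctx = 256 is rejected up front
    resp = requests.post("http://localhost:8080/completion", json={
        "prompt": "Lorem ipsum dolor sit amet, " * 200,
        "n_predict": 64,
    })
    assert resp.status_code == 400
    # error message: "the request exceeds the available context size.
    # try increasing the context size or enable context shift"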
examples/server/tests/features/ctx_shift.feature (new file, 62 lines)

@@ -0,0 +1,62 @@
+@llama.cpp
+@ctx_shift
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    And 256 KV cache size
+    And 32 as batch size
+    And 2 slots
+
+  Scenario: Inference with context shift
+    And 64 server max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And a completion request with no api error
+    Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
+    And the completion is truncated
+    And 109 prompt tokens are processed
+
+  Scenario Outline: Inference without context shift
+    And <n_predict> server max tokens to predict
+    And disable context shifting
+    Then the server is starting
+    Then the server is healthy
+    Given a prompt:
+    """
+    Hi how are you
+    """
+    And a completion request with no api error
+    Then <n_token_output> tokens are predicted matching twind|Anna
+    And the completion is <truncated> truncated
+    And 8 prompt tokens are processed
+    Examples:
+      | n_predict | n_token_output | truncated |
+      | 64        | 64             | not       |
+      | -1        | 120            |           |
+
+  Scenario: Inference without context shift (expected error: prompt too long)
+    And disable context shifting
+    Then the server is starting
+    Then the server is healthy
+    Given a prompt:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And a completion request with 400 api error
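Note: these scenarios can be run in isolation via their @ctx_shift tag, e.g. behave --summary --stop --no-capture --tags ctx_shift from examples/server/tests (assuming the server test suite's usual behave invocation; exact flags may vary).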
examples/server/tests/features/embeddings.feature

@@ -12,9 +12,9 @@ Feature: llama.cpp server
     # the bert-bge-small model has context size of 512
     # since the generated prompts are as big as the batch size, we need to set the batch size to 512
     # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
-    And 512 as batch size
-    And 512 as ubatch size
-    And 2048 KV cache size
+    And 128 as batch size
+    And 128 as ubatch size
+    And 512 KV cache size
     And embeddings extraction
     Then the server is starting
     Then the server is healthy
@@ -26,6 +26,20 @@ Feature: llama.cpp server
     """
     Then embeddings are generated
 
+  Scenario: Embedding (error: prompt too long)
+    When embeddings are computed for:
+    """
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+    Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+    """
+    And embeddings request with 500 api error
+
   Scenario: OAI Embeddings compatibility
     Given a model bert-bge-small
     When an OAI compatible embeddings computation request for:
examples/server/tests/features/steps/steps.py

@@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.temperature = None
     context.lora_file = None
+    context.disable_ctx_shift = False
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
 
 @step('{n_predict:d} server max tokens to predict')
 def step_server_n_predict(context, n_predict: int):
-    context.n_server_predict = n_predict
+    context.n_server_predict = n_predict if n_predict > 0 else None
 
 
 @step('{slot_save_path} as slot save path')
@@ -180,6 +181,9 @@ def step_server_embeddings(context):
 def step_server_metrics(context):
     context.server_metrics = True
 
+@step('disable context shifting')
+def step_server_metrics(context):
+    context.disable_ctx_shift = True
 
 @step("the server is starting")
 def step_start_server(context):
@@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
 @step('a completion request with {api_error} api error')
 @async_run_until_complete
 async def step_request_completion(context, api_error: Literal['raised'] | str):
-    expect_api_error = api_error == 'raised'
+    expect_api_error = api_error == 'raised' or api_error != 'no'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await request_completion(context.prompts.pop(),
                                           seeds[0] if seeds is not None else seeds,
@@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
     context.tasks_result.append(completion)
     if context.debug:
         print(f"Completion response: {completion}")
-    if expect_api_error:
+    if api_error == 'raised':
         assert completion == 401, f"completion must be an 401 status code: {completion}"
+    elif api_error.isdigit():
+        api_error_code = int(api_error)
+        assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
@@ -645,6 +652,9 @@ def step_assert_embeddings(context):
     for embedding in context.embeddings:
         assert_embeddings(embedding)
 
+@step('embeddings request with {api_error_code:d} api error')
+def step_assert_embeddings(context, api_error_code: int):
+    assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
 
 @step('an OAI compatible embeddings computation request for')
 @async_run_until_complete
@@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
     async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
                                     "content": content,
                                 }) as response:
-            assert response.status == 200
-            response_json = await response.json()
-            return [response_json['embedding']]
+            if response.status == 200:
+                response_json = await response.json()
+                return [response_json['embedding']]
+            else:
+                return response.status
 
 
 async def request_oai_embeddings(input, seed,
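Note: callers of request_embedding must now handle the union return type, a list of embeddings on success or the HTTP status code (an int) on failure. A short usage sketch:

    # sketch: distinguish the two return shapes of the updated helper
    result = await request_embedding(content, seed=42, base_url="http://localhost:8080")
    if isinstance(result, int):
        print(f"embedding request failed with status {result}")
    else:
        print(f"got {len(result)} embedding(s)")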
@@ -1372,6 +1384,8 @@ def start_server_background(context):
         server_args.append('--verbose')
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
+    if context.disable_ctx_shift:
+        server_args.extend(['--no-context-shift'])
 
     args = [str(arg) for arg in [context.server_path, *server_args]]
     print(f"bench: starting server with: {' '.join(args)}")