Merge e6a5a6c6ac into b549a1bbef
This commit is contained in: commit 7c63eb09d2
4 changed files with 177 additions and 8 deletions
@@ -113,8 +113,10 @@ struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_truncate = 0; // number of tokens after n_keep that will be discarded when the prompt is bigger than the context
int32_t n_predict = -1; // new tokens to predict

std::vector<std::string> antiprompt;
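The new n_truncate slot parameter is read from the request body alongside n_keep and n_discard (see the json_value calls further down in this diff). A minimal usage sketch, assuming a server already running on localhost:8080 and reusing the field names from the test payload later in this diff; the prompt text and values are placeholders taken from the truncation test scenario:

```python
# Minimal sketch of a /completion request exercising "n_truncate" (illustrative only).
import requests

payload = {
    "prompt": "Continue the chat below.\nMe: Hey there, how's it going?\n",
    "n_predict": 32,
    "n_keep": 82,          # tokens kept from the start of the prompt
    "n_truncate": 52,      # tokens after n_keep to drop once the prompt exceeds the context
    "cache_prompt": True,  # n_truncate is intended to be used together with prompt caching
}

response = requests.post("http://localhost:8080/completion", json=payload)
print(response.json()["content"])
```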
@@ -159,6 +161,7 @@ struct server_slot {
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool shifted = false;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@@ -194,6 +197,7 @@ struct server_slot {
void reset() {
n_prompt_tokens = 0;
generated_text = "";
shifted = false;
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -919,6 +923,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.params.n_truncate = json_value(data, "n_truncate", default_params.n_truncate);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -1347,6 +1352,7 @@ struct server_context {
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"n_truncate", slot.params.n_truncate},
{"ignore_eos", ignore_eos},
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
@@ -1434,6 +1440,7 @@ struct server_context {
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"shifted", slot.shifted},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
@@ -1889,6 +1896,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});

@@ -1962,7 +1970,7 @@ struct server_context {

slot.n_past -= n_discard;

slot.truncated = true;
slot.shifted = true;
}
}
}
@@ -1997,6 +2005,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});
}
@@ -2103,9 +2112,11 @@ struct server_context {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_truncate = slot.params.n_truncate;

const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
const int erased_tokens = n_truncate ? n_truncate : half_ctx;

std::vector<llama_token> new_tokens(
prompt_tokens.begin(),
@@ -2113,7 +2124,7 @@ struct server_context {

new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
prompt_tokens.end());

prompt_tokens = std::move(new_tokens);
@@ -2121,6 +2132,17 @@ struct server_context {
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();

if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
LOG_ERROR("prompt - n_truncate > n_ctx", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
});
}

LOG_VERBOSE("input truncated", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@@ -2131,6 +2153,34 @@ struct server_context {
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});

if (n_truncate && slot.params.cache_prompt) {
const int n_keep = slot.params.n_keep + add_bos_token;
for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
}

const int new_cache_size = slot.cache_tokens.size() - n_truncate;
if (new_cache_size >= 0) {
slot.cache_tokens.resize(new_cache_size);

LOG_VERBOSE("cache tokens shift", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
{"new_cache_size", new_cache_size},
{"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
});
} else {
LOG_ERROR("n_truncate needs to be used with cache_prompt", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
});
}
}

GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}

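The window erased by the truncation above starts right after the first n_keep tokens; its length is n_truncate when the request sets it, otherwise the block-aligned default derived from n_left / 2. A small Python sketch of that selection, as a hypothetical stand-in for the C++ above, using the 256/82/52 values from the feature file later in this diff:

```python
# Hypothetical mirror of the truncation window logic above (not the server code itself).
def truncate_prompt(prompt_tokens, n_ctx, n_keep, n_truncate=0):
    if len(prompt_tokens) < n_ctx:
        return prompt_tokens
    n_left = n_ctx - n_keep
    n_block_size = n_left // 2
    erased_blocks = (len(prompt_tokens) - n_keep - n_block_size) // n_block_size
    # default: erase whole blocks; with n_truncate set: erase exactly that many tokens
    erased_tokens = n_truncate if n_truncate else erased_blocks * n_block_size
    return prompt_tokens[:n_keep] + prompt_tokens[n_keep + erased_tokens:]

tokens = list(range(300))                     # pretend prompt of 300 token ids
out = truncate_prompt(tokens, n_ctx=256, n_keep=82, n_truncate=52)
assert out == tokens[:82] + tokens[82 + 52:]  # 248 tokens remain, now below n_ctx
```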
@@ -2145,6 +2195,12 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

LOG_INFO("[cache_tokens, prompt_tokens]", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "common_part", slot.n_past}
});

// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2175,6 +2231,30 @@ struct server_context {
}
}

// shift KV cache if needed
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_truncate = slot.params.n_truncate;
if (n_truncate && slot.params.cache_prompt && slot.truncated) {
llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);

LOG_INFO("kv cache rm", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep },
{ "p1", n_keep + n_truncate }
});

llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);

LOG_INFO("kv cache add", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep + n_truncate },
{ "p1", slot.n_past },
{ "delta", -n_truncate }
});
}

// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
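To spell out what the two KV-cache calls above do: llama_kv_cache_seq_rm drops the cache cells at positions [n_keep, n_keep + n_truncate), and llama_kv_cache_seq_add then shifts every later position back by n_truncate, so the cache lines up with the compacted cache_tokens from the truncation step. A toy, positions-only model of that effect (hypothetical helper, not the llama.cpp API):

```python
# Toy model of the rm + add sequence above, tracking cache positions only.
def kv_rm_then_add(positions, n_keep, n_truncate):
    kept = [p for p in positions if not (n_keep <= p < n_keep + n_truncate)]  # seq_rm
    return [p - n_truncate if p >= n_keep + n_truncate else p for p in kept]  # seq_add, delta = -n_truncate

# 10 cached positions, keep 4, truncate 3: cells 4..6 are dropped, 7..9 slide down to 4..6.
print(kv_rm_then_add(list(range(10)), n_keep=4, n_truncate=3))  # [0, 1, 2, 3, 4, 5, 6]
```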
@@ -51,6 +51,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.n_ga = None
context.n_ga_w = None
context.n_predict = None
context.n_keep = 0
context.n_truncate = 0
context.n_prompts = 0
context.n_server_predict = None
context.slot_save_path = None
@@ -71,6 +73,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None

context.stop_string = []
context.tasks_result = []
context.concurrent_tasks = []
context.prompts = []
@@ -241,19 +244,35 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
await make_completion_request(context, api_error)


@step('an ongoing completion request')
@async_run_until_complete
async def step_request_ongoing_completion(context):
await make_completion_request(context, '', True)


async def make_completion_request(context, api_error: Literal['raised'] | str, ongoing=False):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
prompt = context.prompts[-1] if ongoing else context.prompts.pop()
completion = await request_completion(prompt,
seeds[0] if seeds is not None else seeds,
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
n_keep=context.n_keep,
n_truncate=context.n_truncate,
cache_prompt=context.cache_prompt,
stop_string=context.stop_string,
id_slot=context.id_slot,
expect_api_error=expect_api_error,
user_api_key=context.user_api_key,
temperature=context.temperature)
context.tasks_result.append(completion)
if ongoing and isinstance(completion, dict) and not expect_api_error:
context.prompts[-1] += completion['content']
if context.debug:
print(f"Completion response: {completion}")
if expect_api_error:
@@ -336,6 +355,16 @@ def step_max_tokens(context, max_tokens):
context.n_predict = max_tokens


@step('{n_keep:d} tokens to keep')
def step_keep_tokens(context, n_keep):
context.n_keep = n_keep


@step('{n_truncate:d} tokens to truncate')
def step_truncate_tokens(context, n_truncate):
context.n_truncate = n_truncate


@step('a response format {response_format}')
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
@@ -399,6 +428,11 @@ def step_bos_token(context, bos):
context.bos = bos


@step('a list of stop strings {stop_list}')
def step_stop_string(context, stop_list):
context.stop_string = json.loads(stop_list)


@step('a prefix prompt')
def step_prompt_prefix(context):
context.prompt_prefix = context_text(context)
@@ -510,6 +544,11 @@ def step_many_prompts(context, num_prompts, prompt, seed):
context.n_prompts = len(context.prompts)


@step('an ongoing prompt')
def step_a_ongoing_prompt(context):
context.prompts[-1] += context_text(context)


@step('concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
@@ -861,7 +900,10 @@ async def request_completion(prompt,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None,
n_keep=0,
n_truncate=0,
cache_prompt=False,
stop_string=None,
id_slot=None,
expect_api_error=None,
user_api_key=None,
@@ -884,7 +926,10 @@ async def request_completion(prompt,
"prompt": prompt,
"input_suffix": prompt_suffix,
"n_predict": n_predict if n_predict is not None else -1,
"n_keep": n_keep,
"n_truncate": n_truncate,
"cache_prompt": cache_prompt,
"stop": stop_string if stop_string is not None else [],
"id_slot": id_slot,
"seed": seed if seed is not None else 42,
"temperature": temperature if temperature is not None else 0.8,
44 examples/server/tests/features/truncation.feature Normal file
@@ -0,0 +1,44 @@
# run with: ./tests.sh --no-skipped --tags truncation
@truncation
@slow
Feature: Chat truncation

Background: Server startup
Given a server listening on localhost:8080
And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models
And prompt caching is enabled
And a list of stop strings ["\n"]
And 82 tokens to keep
And 256 KV cache size
And 32 server max tokens to predict
Then the server is starting
Then the server is healthy

Scenario: Correctly truncate the prompt when the prompt exceeds the context size
Given a prompt:
"""
Continue the chat below.
Me: Hey there, how's it going?
You: I'm doing well, thanks for asking! How are you?
Me: I'm doing good, just trying to get some work done. How's your day?
You: My day has been pretty productive so far. I've been working on a new project.
Me: That's great to hear! What's the new project you're working on?
You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.
Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
Me: That's really nice, are you happy with the progress so far?

"""
And an ongoing completion request
Then -1 tokens are predicted matching You:
Given an ongoing prompt:
"""

Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?

"""
And 52 tokens to truncate
And a completion request with no api error
Then -1 tokens are predicted matching You:
# 28 because '\n' stop string is not pushed to the context
And 28 prompt tokens are processed
@@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ]
then
# Start @llama.cpp scenario
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey|truncation' --tags llama.cpp
else
behave "$@"
fi