This commit is contained in:
Paulo de Castro 2024-07-12 00:49:13 -04:00 committed by GitHub
commit 7c63eb09d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 177 additions and 8 deletions

View file

@ -113,9 +113,11 @@ struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_truncate = 0; // number of tokens after n_keep that will be discarded when the prompt is bigger than the context
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
@ -159,6 +161,7 @@ struct server_slot {
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool shifted = false;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@ -194,6 +197,7 @@ struct server_slot {
void reset() {
n_prompt_tokens = 0;
generated_text = "";
shifted = false;
truncated = false;
stopped_eos = false;
stopped_word = false;
@ -919,6 +923,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.params.n_truncate = json_value(data, "n_truncate", default_params.n_truncate);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@ -1347,6 +1352,7 @@ struct server_context {
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"n_truncate", slot.params.n_truncate},
{"ignore_eos", ignore_eos},
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
@ -1434,6 +1440,7 @@ struct server_context {
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"shifted", slot.shifted},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
@ -1889,6 +1896,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});
@ -1962,7 +1970,7 @@ struct server_context {
slot.n_past -= n_discard;
slot.truncated = true;
slot.shifted = true;
}
}
}
@ -1997,6 +2005,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});
}
@ -2103,9 +2112,11 @@ struct server_context {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_truncate = slot.params.n_truncate;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
const int erased_tokens = n_truncate ? n_truncate : half_ctx;
std::vector<llama_token> new_tokens(
prompt_tokens.begin(),
@ -2113,7 +2124,7 @@ struct server_context {
new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
@ -2121,6 +2132,17 @@ struct server_context {
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
LOG_ERROR("prompt - n_truncate > n_ctx", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
});
}
LOG_VERBOSE("input truncated", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@ -2131,6 +2153,34 @@ struct server_context {
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
if (n_truncate && slot.params.cache_prompt) {
const int n_keep = slot.params.n_keep + add_bos_token;
for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
}
const int new_cache_size = slot.cache_tokens.size() - n_truncate;
if (new_cache_size >= 0) {
slot.cache_tokens.resize(new_cache_size);
LOG_VERBOSE("cache tokens shift", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
{"new_cache_size", new_cache_size},
{"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
});
} else {
LOG_ERROR("n_truncate needs to be used with cache_prompt", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
});
}
}
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@ -2145,6 +2195,12 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
LOG_INFO("[cache_tokens, prompt_tokens]", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "common_part", slot.n_past}
});
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@ -2175,6 +2231,30 @@ struct server_context {
}
}
// shift KV cache if needed
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_truncate = slot.params.n_truncate;
if (n_truncate && slot.params.cache_prompt && slot.truncated) {
llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
LOG_INFO("kv cache rm", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep },
{ "p1", n_keep + n_truncate }
});
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
LOG_INFO("kv cache add", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep + n_truncate },
{ "p1", slot.n_past },
{ "delta", -n_truncate }
});
}
// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {

View file

@ -51,6 +51,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.n_ga = None
context.n_ga_w = None
context.n_predict = None
context.n_keep = 0
context.n_truncate = 0
context.n_prompts = 0
context.n_server_predict = None
context.slot_save_path = None
@ -71,6 +73,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.stop_string = []
context.tasks_result = []
context.concurrent_tasks = []
context.prompts = []
@ -241,19 +244,35 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
await make_completion_request(context, api_error)
@step('an ongoing completion request')
@async_run_until_complete
async def step_request_ongoing_completion(context):
await make_completion_request(context, '', True)
async def make_completion_request(context, api_error: Literal['raised'] | str, ongoing=False):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
prompt = context.prompts[-1] if ongoing else context.prompts.pop()
completion = await request_completion(prompt,
seeds[0] if seeds is not None else seeds,
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
n_keep=context.n_keep,
n_truncate=context.n_truncate,
cache_prompt=context.cache_prompt,
stop_string=context.stop_string,
id_slot=context.id_slot,
expect_api_error=expect_api_error,
user_api_key=context.user_api_key,
temperature=context.temperature)
context.tasks_result.append(completion)
if ongoing and isinstance(completion, dict) and not expect_api_error:
context.prompts[-1] += completion['content']
if context.debug:
print(f"Completion response: {completion}")
if expect_api_error:
@ -336,6 +355,16 @@ def step_max_tokens(context, max_tokens):
context.n_predict = max_tokens
@step('{n_keep:d} tokens to keep')
def step_keep_tokens(context, n_keep):
context.n_keep = n_keep
@step('{n_truncate:d} tokens to truncate')
def step_truncate_tokens(context, n_truncate):
context.n_truncate = n_truncate
@step('a response format {response_format}')
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
@ -399,6 +428,11 @@ def step_bos_token(context, bos):
context.bos = bos
@step('a list of stop strings {stop_list}')
def step_stop_string(context, stop_list):
context.stop_string = json.loads(stop_list)
@step('a prefix prompt')
def step_prompt_prefix(context):
context.prompt_prefix = context_text(context)
@ -510,6 +544,11 @@ def step_many_prompts(context, num_prompts, prompt, seed):
context.n_prompts = len(context.prompts)
@step('an ongoing prompt')
def step_a_ongoing_prompt(context):
context.prompts[-1] += context_text(context)
@step('concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
@ -861,7 +900,10 @@ async def request_completion(prompt,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None,
n_keep=0,
n_truncate=0,
cache_prompt=False,
stop_string=None,
id_slot=None,
expect_api_error=None,
user_api_key=None,
@ -884,7 +926,10 @@ async def request_completion(prompt,
"prompt": prompt,
"input_suffix": prompt_suffix,
"n_predict": n_predict if n_predict is not None else -1,
"n_keep": n_keep,
"n_truncate": n_truncate,
"cache_prompt": cache_prompt,
"stop": stop_string if stop_string is not None else [],
"id_slot": id_slot,
"seed": seed if seed is not None else 42,
"temperature": temperature if temperature is not None else 0.8,

View file

@ -0,0 +1,44 @@
# run with: ./tests.sh --no-skipped --tags truncation
@truncation
@slow
Feature: Chat truncation
Background: Server startup
Given a server listening on localhost:8080
And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models
And prompt caching is enabled
And a list of stop strings ["\n"]
And 82 tokens to keep
And 256 KV cache size
And 32 server max tokens to predict
Then the server is starting
Then the server is healthy
Scenario: Correctly truncate the prompt when the prompt exceeds the context size
Given a prompt:
"""
Continue the chat below.
Me: Hey there, how's it going?
You: I'm doing well, thanks for asking! How are you?
Me: I'm doing good, just trying to get some work done. How's your day?
You: My day has been pretty productive so far. I've been working on a new project.
Me: That's great to hear! What's the new project you're working on?
You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.
Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
Me: That's really nice, are you happy with the progress so far?
"""
And an ongoing completion request
Then -1 tokens are predicted matching You:
Given an ongoing prompt:
"""
Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
"""
And 52 tokens to truncate
And a completion request with no api error
Then -1 tokens are predicted matching You:
# 28 because '\n' stop string is not pushed to the context
And 28 prompt tokens are processed

View file

@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ]
then
# Start @llama.cpp scenario
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey|truncation' --tags llama.cpp
else
behave "$@"
fi