server: avoid breaking KV cache when prompt >= n_ctx

Paulo de Castro 2024-04-28 00:34:07 -03:00
parent 4dba7e8114
commit 91d94eeebd


@@ -100,10 +100,11 @@ struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_truncate = 0; // number of tokens after n_keep to drop when the prompt exceeds n_ctx, 0 defaults to whole half-context blocks
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
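// (editor's note, illustrative) n_truncate is the new per-request parameter introduced by this
// commit: when a prompt no longer fits in the slot's context, exactly this many tokens after
// n_keep are dropped instead of the default "erase half-context blocks" behaviour, so a client
// that manages its own transcript knows which tokens were removed and the cached prefix stays
// reusable. A hedged example request (prompt/cache_prompt/n_keep follow the existing /completion
// API; the values are made up):
//
//   POST /completion
//   {
//     "prompt":       "<long chat transcript>",
//     "cache_prompt": true,
//     "n_keep":       64,
//     "n_truncate":   256
//   }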
@@ -170,6 +171,7 @@ struct server_slot {
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool shifted = false;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@@ -190,7 +192,7 @@ struct server_slot {
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width
int32_t n_past_se = 0; // self-extend
int32_t n_past_se = 0; // self-extend
// stats
size_t n_sent_text = 0; // number of sent text character
@@ -205,6 +207,7 @@ struct server_slot {
void reset() {
n_prompt_tokens = 0;
generated_text = "";
shifted = false;
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -701,7 +704,7 @@ struct server_context {
return res > 0;
}
//MARK: Init
void init() {
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
@@ -854,6 +857,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.params.n_truncate = json_value(data, "n_truncate", default_params.n_truncate);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -1282,6 +1286,7 @@ struct server_context {
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"n_truncate", slot.params.n_truncate},
{"ignore_eos", ignore_eos},
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
@@ -1369,6 +1374,7 @@ struct server_context {
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"shifted", slot.shifted},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
@@ -1784,6 +1790,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});
@@ -1857,7 +1864,7 @@ struct server_context {
slot.n_past -= n_discard;
slot.truncated = true;
slot.shifted = true;
}
}
}
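// (editor's note, illustrative) before this change the context-shift path above set
// slot.truncated; it now sets the new slot.shifted flag instead, so the two situations are
// reported separately: "shifted" means the KV cache was slid during generation, while
// "truncated" keeps meaning the input prompt itself was cut down to fit n_ctx.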
@@ -1892,6 +1899,7 @@ struct server_context {
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()},
{"shifted", slot.shifted},
{"truncated", slot.truncated}
});
}
@@ -1988,9 +1996,11 @@ struct server_context {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_truncate = slot.params.n_truncate;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
const int erased_tokens = n_truncate ? n_truncate : half_ctx;
std::vector<llama_token> new_tokens(
prompt_tokens.begin(),
@@ -1998,7 +2008,7 @@ struct server_context {
new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
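// (editor's note, illustrative) a worked example of the two paths above, assuming
// n_ctx = 512, n_keep = 64 and a 600-token prompt:
//
//   n_left       = 512 - 64 = 448
//   n_block_size = 448 / 2  = 224
//   default  (n_truncate == 0):   erased = ((600 - 64 - 224) / 224) * 224 = 224
//   explicit (n_truncate == 256): erased = 256
//
// either way the tokens [n_keep, n_keep + erased) are removed and the tail is kept, but only
// the explicit path lets the client predict exactly which tokens disappeared.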
@@ -2006,6 +2016,17 @@ struct server_context {
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
LOG_ERROR("prompt - n_truncate > n_ctx", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
});
}
LOG_VERBOSE("input truncated", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@@ -2016,6 +2037,29 @@ struct server_context {
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
if (n_truncate && slot.params.cache_prompt) {
const int n_keep = slot.params.n_keep + add_bos_token;
for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
}
const int new_cache_size = slot.cache_tokens.size() - n_truncate;
if (new_cache_size >= 0) {
slot.cache_tokens.resize(new_cache_size);
LOG_VERBOSE("cache tokens shift", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_truncate", n_truncate},
{"new_cache_size", new_cache_size},
{"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
});
} // else: n_truncate was requested without a previously cached prompt to shift
}
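// (editor's note, illustrative) the loop above is a manual left-shift of the cached tokens;
// a sketch of the equivalent operation on the std::vector, assuming the cache still holds at
// least n_keep + n_truncate tokens:
//
//   auto & c = slot.cache_tokens;
//   c.erase(c.begin() + n_keep, c.begin() + n_keep + n_truncate);
//
// i.e. the n_truncate tokens right after the kept prefix are dropped so the cache lines up with
// the truncated prompt, and the common-prefix match below can still reuse it.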
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@@ -2030,6 +2074,12 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
LOG_INFO("[cached_tokens, prompt_tokens]", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "common_part", slot.n_past}
});
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2060,6 +2110,30 @@ struct server_context {
}
}
// shift KV cache if needed
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_truncate = slot.params.n_truncate;
if (n_truncate && slot.params.cache_prompt) {
llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
LOG_INFO("kv cache rm", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep },
{ "p1", n_keep + n_truncate }
});
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
LOG_INFO("kv cache add", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", n_keep + n_truncate },
{ "p1", slot.n_past },
{ "delta", -n_truncate }
});
}
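// (editor's note, illustrative) the two calls above replay, inside the KV cache, the shift that
// was just applied to cache_tokens: seq_rm drops the cells holding the truncated tokens and
// seq_add slides the remaining cells back by n_truncate so their positions match the shortened
// prompt. Schematically, for n_keep = 2 and n_truncate = 3:
//
//   before:        [k0 k1 | t2 t3 t4 | t5 t6 t7 ...]
//   after seq_rm:  [k0 k1 | .. .. .. | t5 t6 t7 ...]   // positions 2..4 removed
//   after seq_add: [k0 k1 | t5 t6 t7 ...]              // tail moved by -3
//
// this is the same rm/add pair the server already uses when shifting context during generation,
// applied here before prompt evaluation so the cached prefix is kept valid instead of recomputed.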
// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {