server: avoid breaking KV cache when prompt >= n_ctx
commit 91d94eeebd
parent 4dba7e8114
1 changed file with 83 additions and 9 deletions
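In short: the change adds a per-request n_truncate parameter and a "shifted" slot flag. When a prompt no longer fits (n_prompt_tokens >= n_ctx, with self-extend disabled), the server erases a span of tokens after the first n_keep; with cache_prompt enabled it now also shifts cache_tokens and the slot's KV cache by the same amount instead of discarding them, so the cached prefix stays reusable. Below is a minimal standalone sketch of the truncation arithmetic, with names mirroring the diff; it is an illustration, not the server code itself:

    // Sketch of this commit's prompt-truncation arithmetic (illustration only).
    // Assumes n_keep < n_ctx and prompt_tokens.size() >= n_ctx.
    #include <vector>

    std::vector<int> truncate_prompt(const std::vector<int> & prompt_tokens,
                                     int n_ctx, int n_keep, int n_truncate) {
        const int n_prompt_tokens = (int) prompt_tokens.size();
        const int n_left          = n_ctx - n_keep;
        const int n_block_size    = n_left / 2;
        // default: erase whole blocks covering roughly half the usable context
        const int half_ctx      = ((n_prompt_tokens - n_keep - n_block_size) / n_block_size) * n_block_size;
        const int erased_tokens = n_truncate ? n_truncate : half_ctx;

        // keep the first n_keep tokens, drop erased_tokens after them, keep the tail
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_tokens,
                          prompt_tokens.end());
        return new_tokens;
    }

A fixed n_truncate only pays off together with cache_prompt: erasing a known, stable span is what makes the KV-cache shift at the end of the diff possible.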
@@ -100,10 +100,11 @@ struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

-    uint32_t seed      = -1; // RNG seed
-    int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
-    int32_t  n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t  n_predict = -1; // new tokens to predict
+    uint32_t seed       = -1; // RNG seed
+    int32_t  n_keep     =  0; // number of tokens to keep from initial prompt
+    int32_t  n_discard  =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t  n_truncate =  0; // number of tokens to erase after n_keep when the prompt exceeds n_ctx, 0 defaults to about half the context
+    int32_t  n_predict  = -1; // new tokens to predict

     std::vector<std::string> antiprompt;
@@ -170,6 +171,7 @@ struct server_slot {
     bool infill         = false;
     bool embedding      = false;
     bool has_next_token = true;
+    bool shifted        = false;
     bool truncated      = false;
     bool stopped_eos    = false;
     bool stopped_word   = false;
@@ -205,6 +207,7 @@ struct server_slot {
     void reset() {
         n_prompt_tokens = 0;
         generated_text  = "";
+        shifted         = false;
         truncated       = false;
         stopped_eos     = false;
         stopped_word    = false;
@@ -701,7 +704,7 @@ struct server_context {

        return res > 0;
    }

-
+    //MARK: Init
    void init() {
        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
@@ -854,6 +857,7 @@ struct server_context {
        slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
        slot.params.n_keep       = json_value(data, "n_keep",      slot.params.n_keep);
        slot.params.n_discard    = json_value(data, "n_discard",   default_params.n_discard);
+       slot.params.n_truncate   = json_value(data, "n_truncate",  default_params.n_truncate);
        slot.sparams.seed        = json_value(data, "seed",        default_sparams.seed);
        slot.sparams.n_probs     = json_value(data, "n_probs",     default_sparams.n_probs);
        slot.sparams.min_keep    = json_value(data, "min_keep",    default_sparams.min_keep);
@@ -1282,6 +1286,7 @@ struct server_context {
            {"n_predict",         slot.params.n_predict}, // TODO: fix duplicate key n_predict
            {"n_keep",            slot.params.n_keep},
            {"n_discard",         slot.params.n_discard},
+           {"n_truncate",        slot.params.n_truncate},
            {"ignore_eos",        ignore_eos},
            {"stream",            slot.params.stream},
            {"logit_bias",        slot.sparams.logit_bias},
@@ -1369,6 +1374,7 @@ struct server_context {
            {"tokens_evaluated",    slot.n_prompt_tokens},
            {"generation_settings", get_formated_generation(slot)},
            {"prompt",              slot.prompt},
+           {"shifted",             slot.shifted},
            {"truncated",           slot.truncated},
            {"stopped_eos",         slot.stopped_eos},
            {"stopped_word",        slot.stopped_word},
@@ -1784,6 +1790,7 @@ struct server_context {
                {"n_past",          slot.n_past},
                {"n_system_tokens", system_tokens.size()},
                {"n_cache_tokens",  slot.cache_tokens.size()},
+               {"shifted",         slot.shifted},
                {"truncated",       slot.truncated}
            });
@@ -1857,7 +1864,7 @@ struct server_context {

                    slot.n_past -= n_discard;

-                   slot.truncated = true;
+                   slot.shifted = true;
                }
            }
        }
@@ -1892,6 +1899,7 @@ struct server_context {
                    {"n_past",          slot.n_past},
                    {"n_system_tokens", system_tokens.size()},
                    {"n_cache_tokens",  slot.cache_tokens.size()},
+                   {"shifted",         slot.shifted},
                    {"truncated",       slot.truncated}
                });
            }
@@ -1988,9 +1996,11 @@ struct server_context {
            // if input prompt is too big, truncate it (if group attention self-extend is disabled)
            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
-               const int n_left = slot.n_ctx - slot.params.n_keep;
+               const int n_left     = slot.n_ctx - slot.params.n_keep;
+               const int n_truncate = slot.params.n_truncate;

                const int n_block_size = n_left / 2;
-               const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+               const int half_ctx      = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
+               const int erased_tokens = n_truncate ? n_truncate : half_ctx;

                std::vector<llama_token> new_tokens(
                        prompt_tokens.begin(),
@@ -1998,7 +2008,7 @@ struct server_context {

                new_tokens.insert(
                        new_tokens.end(),
-                       prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                       prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
                        prompt_tokens.end());

                prompt_tokens = std::move(new_tokens);
@@ -2006,6 +2016,17 @@ struct server_context {
                slot.truncated = true;
                slot.n_prompt_tokens = prompt_tokens.size();

+               if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
+                   LOG_ERROR("prompt - n_truncate > n_ctx", {
+                       {"id_slot",    slot.id},
+                       {"id_task",    slot.id_task},
+                       {"n_ctx",      slot.n_ctx},
+                       {"n_keep",     slot.params.n_keep},
+                       {"n_left",     n_left},
+                       {"n_truncate", n_truncate},
+                   });
+               }
+
                LOG_VERBOSE("input truncated", {
                    {"id_slot", slot.id},
                    {"id_task", slot.id_task},
@@ -2016,6 +2037,29 @@ struct server_context {
                    {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                });

+               if (n_truncate && slot.params.cache_prompt) {
+                   const int n_keep = slot.params.n_keep + add_bos_token;
+                   for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
+                       slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
+                   }
+
+                   const int new_cache_size = slot.cache_tokens.size() - n_truncate;
+                   if (new_cache_size >= 0) {
+                       slot.cache_tokens.resize(new_cache_size);
+
+                       LOG_VERBOSE("cache tokens shift", {
+                           {"id_slot",        slot.id},
+                           {"id_task",        slot.id_task},
+                           {"n_ctx",          slot.n_ctx},
+                           {"n_keep",         slot.params.n_keep},
+                           {"n_left",         n_left},
+                           {"n_truncate",     n_truncate},
+                           {"new_cache_size", new_cache_size},
+                           {"cache_tokens",   tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
+                       });
+                   } // else: somebody is trying to use n_truncate without a previously cached prompt
+               }
+
                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
            }
@@ -2030,6 +2074,12 @@ struct server_context {
                    // reuse any previously computed tokens that are common with the new prompt
                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

+                   LOG_INFO("[cached_tokens, prompt_tokens]", {
+                       { "id_slot",     slot.id },
+                       { "id_task",     slot.id_task },
+                       { "common_part", slot.n_past }
+                   });
+
                    // push the prompt into the sampling context (do not apply grammar)
                    for (int i = 0; i < slot.n_past; ++i) {
                        llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2060,6 +2110,30 @@ struct server_context {
                        }
                    }

+                   // shift KV cache if needed
+                   const int n_keep     = slot.params.n_keep + add_bos_token;
+                   const int n_truncate = slot.params.n_truncate;
+                   if (n_truncate && slot.params.cache_prompt) {
+                       llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
+
+                       LOG_INFO("kv cache rm", {
+                           { "id_slot", slot.id },
+                           { "id_task", slot.id_task },
+                           { "p0",      n_keep },
+                           { "p1",      n_keep + n_truncate }
+                       });
+
+                       llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
+
+                       LOG_INFO("kv cache add", {
+                           { "id_slot", slot.id },
+                           { "id_task", slot.id_task },
+                           { "p0",      n_keep + n_truncate },
+                           { "p1",      slot.n_past },
+                           { "delta",   -n_truncate }
+                       });
+                   }
+
                    // keep only the common part
                    int p0 = (int) system_tokens.size() + slot.n_past;
                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
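The two calls in the last hunk are what keep the KV cache aligned with the shifted cache_tokens: llama_kv_cache_seq_rm drops the erased span [n_keep, n_keep + n_truncate) and llama_kv_cache_seq_add shifts the positions of every later cell by -n_truncate. A toy model of that position bookkeeping (the (token, pos) cell layout is an assumption for illustration, not llama.cpp's internal representation):

    // Toy model of the KV shift. A cached cell is (token, pos); rm/add act on
    // the position range [p0, p1), where p1 == -1 means "to the end".
    #include <utility>
    #include <vector>

    using cell = std::pair<int, int>; // (token, pos)

    void kv_rm(std::vector<cell> & kv, int p0, int p1) {
        std::erase_if(kv, [&](const cell & c) {
            return c.second >= p0 && (p1 < 0 || c.second < p1);
        });
    }

    void kv_add(std::vector<cell> & kv, int p0, int p1, int delta) {
        for (auto & c : kv) {
            if (c.second >= p0 && (p1 < 0 || c.second < p1)) {
                c.second += delta;
            }
        }
    }

    // kv_rm(kv, n_keep, n_keep + n_truncate) followed by
    // kv_add(kv, n_keep + n_truncate, -1, -n_truncate) leaves the surviving
    // cells with contiguous positions again, matching the shifted cache_tokens.

These are the same two primitives the context-shift path already uses during generation; the commit reuses them at prompt-processing time so that common_part() can still match the cached prefix on the next request.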