server: avoid breaking KV cache when prompt >= n_ctx

parent 4dba7e8114
commit 91d94eeebd

1 changed file with 83 additions and 9 deletions
@@ -103,6 +103,7 @@ struct slot_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_truncate = 0; // number of tokens to drop after n_keep when the prompt exceeds n_ctx, 0 defaults to half of the remaining context
     int32_t n_predict = -1; // new tokens to predict

     std::vector<std::string> antiprompt;
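Note: the new `n_truncate` parameter lets a client pin the truncation step to an exact token count instead of the old half-of-remaining-context heuristic, which is what makes the cached prefix predictable across requests. A minimal sketch of the selection rule, condensed from the @@ -1988 hunk below (the helper name is hypothetical; the arithmetic is copied from the patch):

// Hypothetical helper (not in the patch): how many tokens after n_keep are
// erased when a prompt of n_prompt tokens must fit into n_ctx.
static int erased_token_count(int n_prompt, int n_ctx, int n_keep, int n_truncate) {
    const int n_left       = n_ctx - n_keep;   // room left after the kept prefix
    const int n_block_size = n_left / 2;       // the old heuristic works in half-context blocks
    const int half_ctx     = ((n_prompt - n_keep - n_block_size) / n_block_size) * n_block_size;
    return n_truncate ? n_truncate : half_ctx; // explicit request wins; 0 keeps the old behavior
}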
@@ -170,6 +171,7 @@ struct server_slot {
     bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
+    bool shifted = false; // context was shifted during generation (previously conflated with truncated)
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -205,6 +207,7 @@ struct server_slot {
     void reset() {
         n_prompt_tokens = 0;
         generated_text = "";
+        shifted = false;
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -701,7 +704,7 @@ struct server_context {

         return res > 0;
     }
-
+    //MARK: Init
     void init() {
         const int32_t n_ctx_slot = n_ctx / params.n_parallel;

@@ -854,6 +857,7 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
+        slot.params.n_truncate = json_value(data, "n_truncate", default_params.n_truncate);
         slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
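For reference, a hedged sketch of a client request that exercises the new path. The field names match the json_value() reads above; the prompt and numbers are illustrative, and cache_prompt must be set for the KV-cache shift further down to apply:

// Illustrative request body built with nlohmann::json (the library the
// server itself uses); POST the dumped string to /completion.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json req = {
        {"prompt",       "<long conversation...>"}, // placeholder prompt
        {"n_keep",       32},   // always preserve the first 32 prompt tokens
        {"n_truncate",   256},  // on overflow, drop exactly 256 tokens after n_keep
        {"cache_prompt", true}, // keep the slot's token cache between requests
    };
    std::cout << req.dump(2) << std::endl;
}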
@@ -1282,6 +1286,7 @@ struct server_context {
         {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
         {"n_keep", slot.params.n_keep},
         {"n_discard", slot.params.n_discard},
+        {"n_truncate", slot.params.n_truncate},
         {"ignore_eos", ignore_eos},
         {"stream", slot.params.stream},
         {"logit_bias", slot.sparams.logit_bias},
@@ -1369,6 +1374,7 @@ struct server_context {
         {"tokens_evaluated", slot.n_prompt_tokens},
         {"generation_settings", get_formated_generation(slot)},
         {"prompt", slot.prompt},
+        {"shifted", slot.shifted},
         {"truncated", slot.truncated},
         {"stopped_eos", slot.stopped_eos},
         {"stopped_word", slot.stopped_word},
@@ -1784,6 +1790,7 @@ struct server_context {
                 {"n_past", slot.n_past},
                 {"n_system_tokens", system_tokens.size()},
                 {"n_cache_tokens", slot.cache_tokens.size()},
+                {"shifted", slot.shifted},
                 {"truncated", slot.truncated}
             });

@@ -1857,7 +1864,7 @@ struct server_context {

                 slot.n_past -= n_discard;

-                slot.truncated = true;
+                slot.shifted = true;
             }
         }
     }
@@ -1892,6 +1899,7 @@ struct server_context {
                 {"n_past", slot.n_past},
                 {"n_system_tokens", system_tokens.size()},
                 {"n_cache_tokens", slot.cache_tokens.size()},
+                {"shifted", slot.shifted},
                 {"truncated", slot.truncated}
             });
         }
@@ -1988,9 +1996,11 @@ struct server_context {
                     // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                     if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
+                        const int n_truncate = slot.params.n_truncate;

                         const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                        const int half_ctx = ((slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size) * n_block_size;
+                        const int erased_tokens = n_truncate ? n_truncate : half_ctx;

                         std::vector<llama_token> new_tokens(
                             prompt_tokens.begin(),
@@ -1998,7 +2008,7 @@ struct server_context {

                         new_tokens.insert(
                             new_tokens.end(),
-                            prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                            prompt_tokens.begin() + slot.params.n_keep + erased_tokens,
                             prompt_tokens.end());

                         prompt_tokens = std::move(new_tokens);
@@ -2006,6 +2016,17 @@ struct server_context {
                         slot.truncated = true;
                         slot.n_prompt_tokens = prompt_tokens.size();

+                        if (n_truncate && slot.n_prompt_tokens > slot.n_ctx) {
+                            LOG_ERROR("prompt - n_truncate > n_ctx", {
+                                {"id_slot", slot.id},
+                                {"id_task", slot.id_task},
+                                {"n_ctx", slot.n_ctx},
+                                {"n_keep", slot.params.n_keep},
+                                {"n_left", n_left},
+                                {"n_truncate", n_truncate},
+                            });
+                        }
+
                         LOG_VERBOSE("input truncated", {
                             {"id_slot", slot.id},
                             {"id_task", slot.id_task},
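A worked example with illustrative numbers shows how erased_tokens behaves and why the error check above matters:

// Illustrative numbers, not from the patch: slot.n_ctx = 512, n_keep = 32,
// incoming prompt = 1000 tokens.
//   n_left       = 512 - 32 = 480
//   n_block_size = 480 / 2  = 240
//   half_ctx     = ((1000 - 32 - 240) / 240) * 240 = 3 * 240 = 720
//
// n_truncate = 0   -> erased_tokens = 720, prompt shrinks to 280 tokens (fits)
// n_truncate = 600 -> erased_tokens = 600, prompt shrinks to 400 tokens (fits)
// n_truncate = 300 -> prompt shrinks to 700, still >= n_ctx: the LOG_ERROR
//                     above fires and the GGML_ASSERT below will abort, so
//                     the client must pick n_truncate large enough.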
@@ -2016,6 +2037,29 @@ struct server_context {
                             {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                         });

+                        if (n_truncate && slot.params.cache_prompt) {
+                            const int n_keep = slot.params.n_keep + add_bos_token;
+                            for (size_t i = n_keep + n_truncate; i < slot.cache_tokens.size(); i++) {
+                                slot.cache_tokens[i - n_truncate] = slot.cache_tokens[i];
+                            }
+
+                            const int new_cache_size = slot.cache_tokens.size() - n_truncate;
+                            if (new_cache_size >= 0) {
+                                slot.cache_tokens.resize(new_cache_size);
+
+                                LOG_VERBOSE("cache tokens shift", {
+                                    {"id_slot", slot.id},
+                                    {"id_task", slot.id_task},
+                                    {"n_ctx", slot.n_ctx},
+                                    {"n_keep", slot.params.n_keep},
+                                    {"n_left", n_left},
+                                    {"n_truncate", n_truncate},
+                                    {"new_cache_size", new_cache_size},
+                                    {"cache_tokens", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())},
+                                });
+                            } // else: the client requested n_truncate without a previously cached prompt
+                        }
+
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }

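The loop above compacts the slot's token cache in place: everything from n_keep + n_truncate onward slides left by n_truncate, and the vector is then shrunk. A self-contained sketch with toy values (note that in the patch n_keep additionally includes add_bos_token):

#include <cstdio>
#include <vector>

int main() {
    // toy stand-ins for slot.cache_tokens and the patch's parameters
    std::vector<int> cache = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    const int n_keep     = 2; // prefix that must survive
    const int n_truncate = 3; // tokens dropped right after the prefix

    // slide the tail left over the dropped range, as in the patch
    for (size_t i = n_keep + n_truncate; i < cache.size(); i++) {
        cache[i - n_truncate] = cache[i];
    }
    cache.resize(cache.size() - n_truncate);

    for (int t : cache) printf("%d ", t); // prints: 0 1 5 6 7 8 9
}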
@@ -2030,6 +2074,12 @@ struct server_context {
                     // reuse any previously computed tokens that are common with the new prompt
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

+                    LOG_INFO("[cached_tokens, prompt_tokens]", {
+                        { "id_slot", slot.id },
+                        { "id_task", slot.id_task },
+                        { "common_part", slot.n_past }
+                    });
+
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
                         llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -2060,6 +2110,30 @@ struct server_context {
                         }
                     }

+                    // shift KV cache if needed
+                    const int n_keep = slot.params.n_keep + add_bos_token;
+                    const int n_truncate = slot.params.n_truncate;
+                    if (n_truncate && slot.params.cache_prompt) {
+                        llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_truncate);
+
+                        LOG_INFO("kv cache rm", {
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task },
+                            { "p0", n_keep },
+                            { "p1", n_keep + n_truncate }
+                        });
+
+                        llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_truncate, -1, -n_truncate);
+
+                        LOG_INFO("kv cache add", {
+                            { "id_slot", slot.id },
+                            { "id_task", slot.id_task },
+                            { "p0", n_keep + n_truncate },
+                            { "p1", slot.n_past },
+                            { "delta", -n_truncate }
+                        });
+                    }
+
                     // keep only the common part
                     int p0 = (int) system_tokens.size() + slot.n_past;
                     if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
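The seq_rm/seq_add pair keeps the KV cache consistent with the shifted token cache: first the cells for the dropped tokens are removed, then the positions of the surviving tail are relabeled so that common_part() keeps matching on the next request. A comment-level trace with the same toy numbers as the sketch above:

// Toy trace (n_keep = 2, n_truncate = 3, sequence previously at positions 0..9):
//
//   before:   [0 1 | 2 3 4 | 5 6 7 8 9]
//              keep   drop     tail
//   llama_kv_cache_seq_rm (ctx, seq, 2, 5)      // delete cells at positions 2..4
//   llama_kv_cache_seq_add(ctx, seq, 5, -1, -3) // shift positions 5..end by -3
//   after:    [0 1 | 2 3 4 5 6]
//
// Without the seq_add relabeling, the tail would keep its old positions, the
// next prompt's common prefix would no longer line up, and the whole prompt
// would be re-evaluated: the "breaking KV cache" the commit title refers to.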