rename n_ctx to kv_size

Author: Pierrick HYMBERT, 2024-02-18 20:59:26 +01:00 (committed by Georgi Gerganov)
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions
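
At the API level the rename replaces the llama_n_ctx(ctx) accessor with llama_kv_size(ctx), as the hunks below show. A minimal calling-code sketch, assuming the new accessor keeps the old llama_n_ctx(ctx) signature (the helper name report_kv_budget is ours, for illustration only):

    #include <cstdio>

    #include "llama.h"

    // Hypothetical helper (not part of the commit) showing the renamed accessor.
    static int report_kv_budget(const llama_model * model, const llama_context * ctx) {
        const int n_ctx_train = llama_n_ctx_train(model); // training context length
        const int kv_size     = llama_kv_size(ctx);       // was: llama_n_ctx(ctx)

        if (kv_size > n_ctx_train) {
            fprintf(stderr, "warning: kv_size (%d) exceeds n_ctx_train (%d)\n",
                    kv_size, n_ctx_train);
        }
        return kv_size;
    }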

@@ -157,9 +157,9 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.n_ctx != 0 && params.n_ctx < 8) {
-        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
-        params.n_ctx = 8;
+    if (params.kv_size != 0 && params.kv_size < 8) {
+        LOG_TEE("%s: warning: minimum KV size is 8, using minimum size.\n", __func__);
+        params.kv_size = 8;
     }
 
     if (params.rope_freq_base != 0.0) {
@@ -208,12 +208,12 @@ int main(int argc, char ** argv) {
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
-    const int n_ctx       = llama_n_ctx(ctx);
-    LOG("n_ctx: %d\n", n_ctx);
+    const int kv_size     = llama_kv_size(ctx);
+    LOG("kv_size: %d\n", kv_size);
 
-    if (n_ctx > n_ctx_train) {
+    if (kv_size > n_ctx_train) {
         LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, n_ctx);
+                __func__, n_ctx_train, kv_size);
     }
 
     // print system information
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
            LOG_TEE("%s: The session file is empty. A new session will be initialized.\n", __func__);
        } else {
            // The file exists and is not empty
-           session_tokens.resize(n_ctx);
+           session_tokens.resize(kv_size);
            size_t n_token_count_out = 0;
            if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
@@ -289,8 +289,8 @@ int main(int argc, char ** argv) {
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
     }
 
-    if ((int) embd_inp.size() > n_ctx - 4) {
-        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
+    if ((int) embd_inp.size() > kv_size - 4) {
+        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), kv_size - 4);
         return 1;
     }
 
@@ -450,7 +450,7 @@ int main(int argc, char ** argv) {
     }
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
-    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    LOG_TEE("generate: kv_size = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", kv_size, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
@@ -463,7 +463,7 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT
         GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT
         //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT
-        //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
+        //GGML_ASSERT(kv_size >= n_ctx_train * ga_n && "kv_size must be at least n_ctx_train * grp_attn_n"); // NOLINT
         LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
     }
     LOG_TEE("\n\n");
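
Worked example for the commented-out sizing assert above: with n_ctx_train = 4096 and grp_attn_n = 4, self-extend needs kv_size >= 4096 * 4 = 16384.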
@@ -514,9 +514,9 @@ int main(int argc, char ** argv) {
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
+            // Note: (kv_size - 4) here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
-            int max_embd_size = n_ctx - 4;
+            int max_embd_size = kv_size - 4;
 
             // Ensure the input doesn't exceed the context size by truncating embd if necessary.
             if ((int) embd.size() > max_embd_size) {
@@ -533,8 +533,8 @@ int main(int argc, char ** argv) {
            // infinite text generation via context shifting
            // if we run out of context:
            // - take the n_keep first tokens from the original prompt (via n_past)
-           // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-           if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+           // - take half of the last (kv_size - n_keep) tokens and recompute the logits in batches
+           if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > kv_size) {
                if (params.n_predict == -2) {
                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                    break;
@@ -543,8 +543,8 @@ int main(int argc, char ** argv) {
                const int n_left    = n_past - params.n_keep - 1;
                const int n_discard = n_left/2;
 
-               LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                   n_past, n_left, n_ctx, params.n_keep, n_discard);
+               LOG("context full, swapping: n_past = %d, n_left = %d, kv_size = %d, n_keep = %d, n_discard = %d\n",
+                   n_past, n_left, kv_size, params.n_keep, n_discard);
 
                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
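
The swap arithmetic in this hunk can be read as a standalone plan: keep the first n_keep tokens, evict half of the remaining n_left from the KV cache, then shift the survivors back by n_discard. A self-contained sketch (shift_plan and plan_context_shift are made-up names; only the formulas come from the code above):

    // Hypothetical illustration of the context-shift bookkeeping above.
    struct shift_plan {
        int n_discard;  // tokens evicted from the KV cache
        int rm_begin;   // first evicted position:        params.n_keep + 1
        int rm_end;     // one past the last evicted one: params.n_keep + n_discard + 1
        int new_n_past; // n_past after the -n_discard shift
    };

    static shift_plan plan_context_shift(int n_past, int n_keep) {
        const int n_left    = n_past - n_keep - 1; // tokens eligible for eviction
        const int n_discard = n_left / 2;          // evict half of them

        return { n_discard, n_keep + 1, n_keep + 1 + n_discard, n_past - n_discard };
    }

For example, n_past = 4096 with n_keep = 256 gives n_left = 3839, n_discard = 1919, and generation resumes at n_past = 2177.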
@@ -666,7 +666,7 @@ int main(int argc, char ** argv) {
            LOG("n_past = %d\n", n_past);
 
            // Display total tokens alongside total time
            if (params.n_print > 0 && n_past % params.n_print == 0) {
-               LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+               LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, kv_size);
            }
        }