rename n_ctx to kv_size

Pierrick HYMBERT 2024-02-18 20:59:26 +01:00 committed by Georgi Gerganov
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions

@@ -14,7 +14,8 @@ In this section, we cover the most commonly used options for running the `infill`
 - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
 - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
 - `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
-- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
+- `-c N`, `--ctx-size N`: Deprecated, use `--kv-size` instead.
+- `-kv N`, `--kv-size N`: Specify the total size of the KV cache for the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
 
 ## Input Prompts
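As a usage note on the option change above: both spellings now funnel into the same setting, with `-c/--ctx-size` kept only as a deprecated alias for `-kv/--kv-size`. Below is a minimal sketch of that aliasing. Only the flag names and the `kv_size` field come from this commit; the struct name, the helper, and the parsing loop are illustrative assumptions, not the actual argument-parsing code.

```cpp
// Illustrative sketch only: how -c/--ctx-size could remain a deprecated alias for --kv-size.
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct params_sketch {
    int kv_size = 512; // total size of the KV cache for the prompt context (default 512)
};

// Hypothetical helper: the real parsing lives elsewhere in the codebase.
static void parse_kv_size(int argc, char ** argv, params_sketch & params) {
    for (int i = 1; i < argc - 1; i++) {
        if (!std::strcmp(argv[i], "-kv") || !std::strcmp(argv[i], "--kv-size")) {
            params.kv_size = std::atoi(argv[++i]);
        } else if (!std::strcmp(argv[i], "-c") || !std::strcmp(argv[i], "--ctx-size")) {
            std::fprintf(stderr, "warning: --ctx-size is deprecated, use --kv-size instead\n");
            params.kv_size = std::atoi(argv[++i]);
        }
    }
}
```

Either way, the rest of the program only ever reads `params.kv_size`, which is what the source changes below rely on.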

@@ -135,9 +135,9 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.n_ctx != 0 && params.n_ctx < 8) {
+    if (params.kv_size != 0 && params.kv_size < 8) {
         LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
-        params.n_ctx = 8;
+        params.kv_size = 8;
     }
     if (params.instruct) {
         printf("\n************\n");
@@ -225,12 +225,12 @@ int main(int argc, char ** argv) {
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
-    const int n_ctx = llama_n_ctx(ctx);
-    LOG("n_ctx: %d\n", n_ctx);
+    const int kv_size = llama_kv_size(ctx);
+    LOG("kv_size: %d\n", kv_size);
 
-    if (n_ctx > n_ctx_train) {
+    if (kv_size > n_ctx_train) {
         LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, n_ctx);
+                __func__, n_ctx_train, kv_size);
     }
 
     // print system information
@@ -291,8 +291,8 @@ int main(int argc, char ** argv) {
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
     }
 
-    if ((int) embd_inp.size() > n_ctx - 4) {
-        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
+    if ((int) embd_inp.size() > kv_size - 4) {
+        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), kv_size - 4);
         return 1;
     }
@@ -366,7 +366,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
-    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    LOG_TEE("generate: kv_size = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", kv_size, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
     LOG_TEE("\n##### Infill mode #####\n\n");
@@ -416,9 +416,9 @@ int main(int argc, char ** argv) {
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
+            // Note: kv_size - 4 here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
-            int max_embd_size = n_ctx - 4;
+            int max_embd_size = kv_size - 4;
 
             // Ensure the input doesn't exceed the context size by truncating embd if necessary.
             if ((int) embd.size() > max_embd_size) {
@@ -434,8 +434,8 @@ int main(int argc, char ** argv) {
             // infinite text generation via context swapping
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
-            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+            // - take half of the last (kv_size - n_keep) tokens and recompute the logits in batches
+            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > kv_size) {
                 if (params.n_predict == -2) {
                     LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                     break;
@@ -444,8 +444,8 @@ int main(int argc, char ** argv) {
                 const int n_left    = n_past - params.n_keep - 1;
                 const int n_discard = n_left/2;
 
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+                LOG("context full, swapping: n_past = %d, n_left = %d, kv_size = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, kv_size, params.n_keep, n_discard);
 
                 llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
                 llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
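To make the swap arithmetic above concrete, here is a small self-contained walk-through with made-up sample numbers; only the `n_left`/`n_discard` formulas and the two KV-cache calls (shown as comments) mirror the code above.

```cpp
// Standalone sketch of the context-swap arithmetic; kv_size, n_keep and n_past are sample values.
#include <cstdio>

int main() {
    const int kv_size = 512; // sample: total KV cache size
    const int n_keep  = 16;  // sample: tokens always kept from the original prompt
    const int n_past  = 510; // sample: the cache is effectively full

    // Same formulas as in the diff above.
    const int n_left    = n_past - n_keep - 1; // 510 - 16 - 1 = 493
    const int n_discard = n_left/2;            // 246

    // The program would then call (ctx being the llama context):
    //   llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1,             n_keep + n_discard + 1);
    //   llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);
    std::printf("kv_size = %d, n_left = %d, n_discard = %d\n", kv_size, n_left, n_discard);
    std::printf("remove KV cells [%d, %d), then shift [%d, %d) left by %d\n",
                n_keep + 1, n_keep + n_discard + 1,
                n_keep + 1 + n_discard, n_past, n_discard);
    return 0;
}
```

With these numbers, roughly half of the evictable cells are dropped and the most recent tokens slide down to sit right after the kept prompt prefix, freeing space for further generation.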