rename n_ctx to kv_size

Authored by Pierrick HYMBERT on 2024-02-18 20:59:26 +01:00; committed by Georgi Gerganov
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions
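
In caller code the rename is mechanical: the kv_size field of llama_context_params replaces n_ctx, and llama_kv_size() replaces llama_n_ctx() as the accessor. A minimal before/after sketch; the renamed field and accessor come from the hunks below, while the surrounding setup (llama_new_context_with_model, the example size of 4096) is illustrative rather than taken from this commit:

    #include "llama.h"

    // assumes a llama_model * model loaded elsewhere

    // before this commit:
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 4096;                  // size of the KV cache, in tokens
    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    const int n_ctx = llama_n_ctx(ctx);       // read the size back

    // after this commit (same semantics, renamed field and accessor):
    ctx_params.kv_size = 4096;
    llama_context * ctx2 = llama_new_context_with_model(model, ctx_params);
    const int kv_size = llama_kv_size(ctx2);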


@@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.seed = seed;
-    ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
+    ctx_params.kv_size = llama_n_ctx_train(model)*n_grp + n_keep;
     ctx_params.n_batch = 512;
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -121,12 +121,12 @@ int main(int argc, char ** argv) {
     // total length of the sequences including the prompt
     const int n_len = n_tokens_all + n_predict;
-    const int n_ctx = llama_n_ctx(ctx) - n_keep;
-    const int n_kv_req = llama_n_ctx(ctx);
+    const int kv_size = llama_kv_size(ctx) - n_keep;
+    const int n_kv_req = llama_kv_size(ctx);
     const int n_batch = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+    LOG_TEE("\n%s: n_len = %d, kv_size = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, kv_size, n_kv_req, n_grp, n_batch);
     // print the prompt token-by-token
@@ -140,7 +140,7 @@ int main(int argc, char ** argv) {
     int n_past = 0;
     // fill the KV cache
-    for (int i = 0; i < n_ctx; i += n_batch) {
+    for (int i = 0; i < kv_size; i += n_batch) {
         if (i > 0 && n_grp > 1) {
             // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
             const int ib = i/n_batch - 1;
@@ -174,13 +174,13 @@ int main(int argc, char ** argv) {
         }
     }
-    for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
+    for (int i = kv_size; i < n_tokens_all; i += n_batch) {
         const int n_discard = n_batch;
         LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
         llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, kv_size, -n_discard);
         n_past -= n_discard;
@@ -203,13 +203,13 @@ int main(int argc, char ** argv) {
     }
     {
-        const int n_discard = n_past - n_ctx + n_predict;
+        const int n_discard = n_past - kv_size + n_predict;
         if (n_discard > 0) {
             LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
             llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, kv_size, -n_discard);
             n_past -= n_discard;
         }
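
The last two hunks apply the same context-shifting pattern: keep the first n_keep tokens, evict the next n_discard, and slide the remainder of the cache back so decoding can continue within kv_size. A sketch of that pattern as a standalone helper; the two llama_kv_cache_* calls and their arguments are exactly those in the diff, but the wrapper function itself is hypothetical:

    #include "llama.h"

    // Hypothetical helper wrapping the shift pattern from the hunks above.
    static void kv_shift(llama_context * ctx, int n_keep, int n_discard, int kv_size, int & n_past) {
        // evict tokens [n_keep, n_keep + n_discard) from sequence 0
        llama_kv_cache_seq_rm   (ctx, 0, n_keep,             n_keep + n_discard);
        // slide tokens [n_keep + n_discard, kv_size) back by n_discard positions
        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, kv_size, -n_discard);
        // decoding resumes n_discard positions earlier
        n_past -= n_discard;
    }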