rename n_ctx to kv_size

Pierrick HYMBERT, 2024-02-18 20:59:26 +01:00 (committed by Georgi Gerganov)
parent ef96e8b1f7
commit 606873401c
48 changed files with 403 additions and 393 deletions


@@ -155,7 +155,7 @@ struct llama_client_slot
     int64_t t_last_used = -1;

     // generation props
-    int32_t n_ctx = 0; // context size per slot
+    int32_t kv_size = 0; // KV size per slot
     int32_t n_past = 0;
     int32_t n_decoded = 0;
     int32_t n_remaining = -1;
@@ -325,7 +325,7 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;

-    int32_t n_ctx; // total context for all clients / slots
+    int32_t kv_size; // total KV Cache for all clients / slots

     // system prompt
     bool system_need_update = false;
@@ -369,8 +369,8 @@ struct llama_server_context
             return false;
         }

-        if (params.n_ctx < 2048) { // request larger context for the image embedding
-            params.n_ctx = 2048;
+        if (params.kv_size < 2048) { // request larger context for the image embedding
+            params.kv_size = 2048;
         }
     }
@@ -392,7 +392,7 @@ struct llama_server_context
             }
         }

-        n_ctx = llama_n_ctx(ctx);
+        kv_size = llama_kv_size(ctx);

         add_bos_token = llama_should_add_bos_token(model);
@@ -403,7 +403,7 @@ struct llama_server_context
         // create slots
         all_slots_are_idle = true;

-        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+        const int32_t kv_size_slot = kv_size / params.n_parallel;

         LOG_TEE("Available slots:\n");
         for (int i = 0; i < params.n_parallel; i++)
@@ -411,10 +411,10 @@ struct llama_server_context
             llama_client_slot slot;

             slot.id = i;
-            slot.n_ctx = n_ctx_slot;
+            slot.kv_size = kv_size_slot;
             slot.n_predict = params.n_predict;

-            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+            LOG_TEE(" -> Slot %i - max KV Size: %i\n", slot.id, kv_size_slot);

             const int ga_n = params.grp_attn_n;
             const int ga_w = params.grp_attn_w;
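The per-slot budget in the two hunks above is plain integer division of the total KV size across the parallel slots. A minimal standalone sketch of that split, with made-up numbers (not part of the commit); a non-divisible remainder of cells is simply left unused:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t kv_size      = 8192; // total KV cache cells (--kv-size)
        const int32_t n_parallel   = 3;    // number of client slots
        const int32_t kv_size_slot = kv_size / n_parallel; // 2730 cells per slot

        std::printf("per-slot KV size: %d (%d cells unused)\n",
                    kv_size_slot, kv_size - kv_size_slot * n_parallel);
        return 0;
    }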
@@ -423,7 +423,7 @@ struct llama_server_context
                 GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
                 GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
                 //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
-                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
+                //GGML_ASSERT(kv_size >= n_ctx_train * ga_n && "kv_size must be at least n_ctx_train * ga_n"); // NOLINT

                 LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
             }
@@ -439,7 +439,7 @@ struct llama_server_context
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;

-        batch = llama_batch_init(n_ctx, 0, params.n_parallel);
+        batch = llama_batch_init(kv_size, 0, params.n_parallel);
     }

     std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -1065,7 +1065,7 @@ struct llama_server_context
         }

         return json {
-            {"n_ctx", slot.n_ctx},
+            {"kv_size", slot.kv_size},
             {"n_predict", slot.n_predict},
             {"model", params.model_alias},
             {"seed", slot.params.seed},
@@ -1474,7 +1474,7 @@ struct llama_server_context
             {
                 if (slot.ga_n == 1)
                 {
-                    if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                    if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.kv_size)
                     {
                         // Shift context
                         const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
@@ -1496,7 +1496,7 @@ struct llama_server_context
                         slot.truncated = true;

                         LOG_VERBOSE("context shift", {
-                            { "n_ctx", n_ctx },
+                            { "kv_size", kv_size },
                             { "n_keep", params.n_keep },
                             { "n_left", n_left },
                         });
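The hunk above only renames the logged field, but it sits in the server's context-shift path: once system plus cached tokens reach the slot's KV budget, the slot keeps its first n_keep tokens, discards half of the rest, and slides the tail down. A hedged sketch of that token bookkeeping (standalone, not the server's code; the real implementation also moves the KV-cache cells themselves through the llama_kv_cache_seq_* API):

    #include <vector>

    // keep the first n_keep tokens, drop the oldest half of the movable
    // region, and let the tail slide down
    std::vector<int> shift_context(std::vector<int> cache_tokens, int n_keep) {
        const int n_left    = (int) cache_tokens.size() - n_keep;
        const int n_discard = n_left / 2;

        if (n_discard > 0) {
            cache_tokens.erase(cache_tokens.begin() + n_keep,
                               cache_tokens.begin() + n_keep + n_discard);
        }
        return cache_tokens; // now has room for roughly n_discard new tokens
    }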
@@ -1598,12 +1598,12 @@ struct llama_server_context
                    {
                        slot.params.n_keep = slot.num_prompt_tokens;
                    }
-                   slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+                   slot.params.n_keep = std::min(slot.kv_size - 4, slot.params.n_keep);

                    // if input prompt is too big, truncate it
-                   if (slot.num_prompt_tokens >= slot.n_ctx)
+                   if (slot.num_prompt_tokens >= slot.kv_size)
                    {
-                       const int n_left = slot.n_ctx - slot.params.n_keep;
+                       const int n_left = slot.kv_size - slot.params.n_keep;
                        const int n_block_size = n_left / 2;
                        const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
@@ -1611,7 +1611,7 @@ struct llama_server_context
                        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());

                        LOG_VERBOSE("input truncated", {
-                           {"n_ctx", slot.n_ctx},
+                           {"kv_size", slot.kv_size},
                            {"n_keep", slot.params.n_keep},
                            {"n_left", n_left},
                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@@ -1620,7 +1620,7 @@ struct llama_server_context
                        prompt_tokens = new_tokens;

                        slot.num_prompt_tokens = prompt_tokens.size();
-                       GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+                       GGML_ASSERT(slot.num_prompt_tokens < slot.kv_size);
                    }

                    if (!slot.params.cache_prompt)
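The truncation in the three hunks above keeps the first n_keep prompt tokens, erases whole blocks of size n_left/2 from the middle, and keeps the tail, so the result always lands strictly below the KV budget (hence the GGML_ASSERT). A hedged standalone sketch of that arithmetic (illustrative, not the server's function; it assumes n_keep was already clamped to kv_size - 4 by the std::min above, so n_block_size >= 2):

    #include <vector>

    std::vector<int> truncate_prompt(const std::vector<int> & prompt_tokens,
                                     int kv_size, int n_keep) {
        const int n_prompt = (int) prompt_tokens.size();
        if (n_prompt < kv_size) {
            return prompt_tokens; // already fits
        }

        const int n_left        = kv_size - n_keep;
        const int n_block_size  = n_left / 2;
        const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size;

        // head: the protected n_keep tokens
        std::vector<int> new_tokens(prompt_tokens.begin(),
                                    prompt_tokens.begin() + n_keep);
        // tail: everything after the erased middle blocks
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
                          prompt_tokens.end());

        return new_tokens; // new_tokens.size() < kv_size always holds
    }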
@@ -1873,7 +1873,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
     printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
-    printf(" -kv N, --kv-size N Specify the total size of the KV cache (default: %d)\n", params.n_ctx);
+    printf(" -kv N, --kv-size N Specify the total size of the KV cache (default: %d)\n", params.kv_size);
     printf(" --rope-scaling {none,linear,yarn}\n");
     printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
     printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
@@ -2043,16 +2043,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            server_print_usage(argv[0], default_params, default_sparams);
            exit(0);
        }
-       else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size")
+       else if (arg == "-c" || arg == "--ctx-size" || arg == "--kv_size")
        {
            if (++i >= argc)
            {
                invalid_param = true;
                break;
            }
-           params.n_ctx = std::stoi(argv[i]);
-           LOG_WARNING("-c,--ctx-size,--ctx_size option is deprecated, use --kv-size instead",
-                       {{"--ctx_size", params.n_ctx}});
+           params.kv_size = std::stoi(argv[i]);
+           LOG_WARNING("-c,--ctx-size,--kv_size option is deprecated, use --kv-size instead",
+                       {{"--kv_size", params.kv_size}});
        }
        else if (arg == "-kv" || arg == "--kv-size" || arg == "--kv_size")
        {
@@ -2061,7 +2061,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                invalid_param = true;
                break;
            }
-           params.n_ctx = std::stoi(argv[i]);
+           params.kv_size = std::stoi(argv[i]);
        }
        else if (arg == "--rope-scaling")
        {