llama : add defrag_thold parameter

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-26 18:19:23 +02:00
parent 30c29f44cc
commit 4e35db1a81
No known key found for this signature in database
GPG key ID: BF970631944C16B7
2 changed files with 6 additions and 3 deletions

View file

@ -1641,6 +1641,7 @@ struct llama_cparams {
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
float defrag_thold;
bool mul_mat_q;
bool offload_kqv;
@ -8007,12 +8008,11 @@ static int llama_decode_internal(
}
// decide if we need to defrag the kv cache
// TODO: should become configurable
{
const float fragmentation = kv_self.n >= 512 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > 0.1f) {
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
llama_kv_cache_defrag(kv_self);
@ -11677,6 +11677,7 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16,
@ -11841,6 +11842,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv;
cparams.do_pooling = params.do_pooling;

View file

@ -243,6 +243,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;