From ed5496fb32e9888abeaa7672aaba4d4251671457 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Mon, 1 Jul 2024 12:35:47 +0200
Subject: [PATCH] update

---
 src/llama.cpp | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 71b7ef622..1f6763573 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2101,7 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
-    uint32_t n_sliding = 0; // sliding window attention (SWA)
+    uint32_t n_swa = 0; // sliding window attention (SWA)
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
 
     // KQ mask per layer, used by sliding window attention (gemma 2)
-    struct ggml_tensor * inp_KQ_mask_SWA;
+    struct ggml_tensor * inp_KQ_mask_swa;
 
     // control vectors
     struct llama_control_vector cvec;
@@ -4715,8 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                hparams.n_sliding = 4096; // default value of gemma 2
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -7794,7 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq = nullptr;
-        lctx.inp_KQ_mask_SWA = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
     }
 
     void free() {
@@ -7954,7 +7954,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);
         ggml_set_input(KQ_mask);
         if (sliding_window) {
-            lctx.inp_KQ_mask_SWA = KQ_mask;
+            lctx.inp_KQ_mask_swa = KQ_mask;
         } else {
             lctx.inp_KQ_mask = KQ_mask;
         }
@@ -12689,14 +12689,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
             float * data_swa = nullptr;
-            const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
 
-            if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-                GGML_ASSERT(lctx.inp_KQ_mask_SWA);
-                GGML_ASSERT(hparams.n_sliding > 0);
-                data = (float *) lctx.inp_KQ_mask->data;
-                data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
-                // because layer masks are alternate for gemma 2, we only need to take first 2 layers
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }
 
             // For causal attention, use only the previous KV cells
@@ -12722,7 +12717,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                 f = -INFINITY;
                             }
                             data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
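
For reference, a minimal standalone sketch of the masking rule the last two hunks switch to: a KV cell is masked for the SWA layers once it lies at least n_swa positions behind the query token, rather than using the old n_keep_swa = n_sliding - batch.n_tokens cutoff. This is not llama.cpp code; the function name swa_mask_row and its parameters are illustrative only.

// Standalone illustration (not part of llama.cpp) of the sliding-window masking rule:
// a KV cell is masked when it is in the future of the query token, or at least
// n_swa positions behind it (mirrors the patched condition above, negated).
#include <cmath>
#include <cstdint>
#include <vector>

// pos   : positions of the KV cells (single sequence, causal attention assumed)
// qpos  : position of the query token
// n_swa : sliding window size (4096 by default for Gemma 2)
// returns one mask row: 0.0f where attention is allowed, -INFINITY where it is masked
static std::vector<float> swa_mask_row(const std::vector<int32_t> & pos, int32_t qpos, uint32_t n_swa) {
    std::vector<float> row(pos.size(), -INFINITY);
    for (size_t i = 0; i < pos.size(); ++i) {
        const bool causal_ok = pos[i] <= qpos;                  // no attention to future tokens
        const bool in_window = qpos - pos[i] < (int32_t) n_swa; // within the sliding window
        if (causal_ok && in_window) {
            row[i] = 0.0f;
        }
    }
    return row;
}

With this formulation the window depends only on the distance between the query position and the cell position, so the mask no longer varies with the batch size, which is what the removed n_keep_swa expression tied it to.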