From 03cb5cda6d6a21d98ac14cf0142a5573c0d1d42c Mon Sep 17 00:00:00 2001
From: Shupei Fan
Date: Mon, 22 Jul 2024 17:06:45 +0800
Subject: [PATCH] use sliding window for phi3

Make the Phi-3 graph attend through the sliding-window KQ mask.

* llm_load_hparams: default hparams.n_swa to 2048 for LLM_ARCH_PHI3 and
  override it from LLM_KV_ATTENTION_SLIDING_WINDOW when the key is
  present.
* Phi-3 graph build: create the mask with build_inp_KQ_mask_swa() and
  pass it to llm_build_kv() instead of the regular KQ mask.
* llama_set_inputs: a graph may now provide only inp_KQ_mask_swa, so the
  regular mask pointer (`data`) and the sliding-window mask pointer
  (`data_swa`) are each resolved, host-checked and written only when the
  corresponding tensor exists. The padding rows up to
  GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) are filled with -INFINITY for
  both masks.

The regular mask pointer is taken from lctx.inp_KQ_mask->data and stored
in `data`; storing it in `data_swa` would leave `data` null and the
regular mask unwritten.
---
 src/llama.cpp | 38 +++++++++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 9 deletions(-)

Note: the indentation of context lines was restored from a
whitespace-collapsed copy of this patch. If a hunk fails on whitespace,
apply with `git am --ignore-whitespace` or `git apply --ignore-whitespace`.

diff --git a/src/llama.cpp b/src/llama.cpp
index 6046e5615..e0299b950 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4974,6 +4974,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PHI3:
             {
+                hparams.n_swa = 2048;
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
@@ -10843,7 +10845,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -10901,7 +10903,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -14108,18 +14110,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "causal attention is not supported by this model"
     );
 
-    if (lctx.inp_KQ_mask) {
+    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data     = nullptr;
             float * data_swa = nullptr;
 
+            if (lctx.inp_KQ_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+
             if (lctx.inp_KQ_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
                 data_swa = (float *) lctx.inp_KQ_mask_swa->data;
             }
 
@@ -14142,7 +14149,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
 
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
@@ -14154,9 +14164,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }
 
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
                     }
                 }
             }