From 166e60bf9b74e51564b6c6e72ae35a11af5709db Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 May 2024 11:33:34 +0300
Subject: [PATCH] ggml : ggml_flash_attn_ext() support ALiBi (CPU)

---
 ggml.c    | 31 ++++++++++++++++++++++---------
 ggml.h    |  3 ++-
 llama.cpp |  2 +-
 3 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/ggml.c b/ggml.c
index d218e84d3..60091186e 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6436,7 +6436,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
     if (mask) {
@@ -6458,7 +6459,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6478,7 +6479,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_ff
@@ -13308,8 +13309,8 @@ static void ggml_compute_forward_soft_max_f32(
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head_kv   = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -15524,8 +15525,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15534,6 +15544,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
+        const int h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;
 
@@ -15557,7 +15570,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
@@ -15628,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
diff --git a/ggml.h b/ggml.h
index fe53c3362..76c332831 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1731,7 +1731,8 @@ extern "C" {
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
-            float                 scale);
+            float                 scale,
+            float                 max_bias);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
diff --git a/llama.cpp b/llama.cpp
index 227bad97d..3a66c3b5e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6537,7 +6537,7 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
 
         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
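
Note (not part of the patch): below is a minimal standalone sketch of the per-head ALiBi slope computation that the hunks above add to ggml_compute_forward_flash_attn_ext_f16(). The helper name alibi_slope and the example values (max_bias = 8.0f, 32 heads) are illustrative assumptions, not taken from the patch; the formula itself mirrors the m0/m1/slope lines added in the diff.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Same formula as the slope added in the flash-attn kernel above:
// heads below the nearest power of two use base m0, the rest use m1.
static float alibi_slope(float max_bias, uint32_t n_head, int h) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: the mask is applied unscaled
    }

    // largest power of two not exceeding n_head
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    return h < (int) n_head_log2 ? powf(m0, h + 1)
                                 : powf(m1, 2*(h - (int) n_head_log2) + 1);
}

int main(void) {
    // illustrative values: max_bias = 8.0f, n_head = 32
    for (int h = 0; h < 8; ++h) {
        printf("head %2d: slope = %f\n", h, alibi_slope(8.0f, 32, h));
    }
    return 0;
}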