From 97c27f59f69be668262ab50fdab327756a45fe25 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 May 2024 13:51:00 +0300 Subject: [PATCH] ggml : ggml_flash_attn_ext() support ALiBi (Metal) --- ggml-metal.m | 76 ++++++++++++++++++++++---------------- ggml-metal.metal | 50 +++++++++++++++++++++---- llama.cpp | 6 +++ tests/test-backend-ops.cpp | 20 ++++++---- 4 files changed, 105 insertions(+), 47 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 1f8943bda..afc4e4beb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute( "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - const int64_t ne31 = src3 ? src3->ne[1] : 0; + //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); @@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute( const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); id pipeline = nil; @@ -2562,34 +2571,37 @@ static enum ggml_status ggml_metal_graph_compute( } [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; - [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&scale length:sizeof( float) atIndex:26]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:27]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:28]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:29]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml-metal.metal b/ggml-metal.metal index 641e576e2..ee9de57a3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16( // prepare diagonal scale matrix simdgroup_float8x8 mscale(scale); + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { @@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope simdgroup_half8x8 mm; simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, TF, 0, false); @@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16( const short T = D + 2*nsg*SH; // shared memory size per query in (half) + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix @@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16( mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 1); - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope if (tiisg == 0) { float4 mm = (float4) mp4[ic/4 + cc]; - mqk = mqk*scale + mm; + mqk = mqk*scale + mm*slope; ss4[cc] = mqk; } @@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; } } diff --git a/llama.cpp b/llama.cpp index 3a66c3b5e..74d683806 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10985,6 +10985,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } } else { // when using kv cache, the mask needs to match the kv cache size diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 216a01359..ab94abc72 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1486,23 +1486,25 @@ struct test_flash_attn_ext : public test_case { const int64_t kv; // kv size const int64_t nb; // batch size + const float max_bias; // ALiBi + std::string vars() override { - return VARS_TO_STR4(hs, nh, kv, nb); + return VARS_TO_STR5(hs, nh, kv, nb, max_bias); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f) + : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias); return out; } }; @@ -2176,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #else for (int hs : { 64, 80, 128, 256, }) { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + for (float max_bias : {0.0f, 8.0f}) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias)); + } } } }