From 9e62e7e10e493db1d55cc0b531f4f44a44ea4a8c Mon Sep 17 00:00:00 2001 From: Shupei Fan Date: Sun, 22 Sep 2024 22:03:18 +0800 Subject: [PATCH] [metal-kernel] add flash_attn_ext_scalar_f16 implementation --- ggml/src/ggml-metal.metal | 288 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b2000323..122bfb5eb 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2799,6 +2799,294 @@ kernel void kernel_flash_attn_ext_vec_f16( template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +half dequantize_load_f16(device const half *xb, short il) { + return xb[il]; +} + +half dequantize_load_q8_0(device const block_q8_0 *xb, short il) { + device const block_q8_0 *xb_ = &xb[il / QK8_0]; + return xb_->d * xb_->qs[il % QK8_0]; +} + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_scalar_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half * sr = (threadgroup half *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half lo[D/NW]; + + // load heads from Q to shared memory + device const float * q_ = (device const float *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D; i += NW) { + if (iq1 < ne01) { + sq[i] = (half) q_[i]; + } else { + sq[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH; i += NW) { + ss[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half mq[D]; + + for (short ii = 0; ii < D; ii += NW) { + short i = ii + tiisg; + mq[i] = sq[i]; + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +// #pragma unroll + for (short cc = 0; cc < C; ++cc) { + float mqk = 0.0; + + device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D; ii += NW) { + const short i = ii + tiisg; + mqk += mq[i] * dequantize_load(pk, i); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + if (mask != q) { + mqk += (mp[ic + cc])*slope; + } + + ss[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +// #pragma unroll + for (short cc = 0; cc < C; ++cc) { + device const block_q * pv = (device const block_q *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += dequantize_load(pv, i) * ss[cc]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D; ii += NW) { + short i = ii + tiisg; + sr[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D; ii += NW) { + short i = ii + tiisg; + sr[i] = sr[i]*ms0 + sr[i + r*D]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D; ii += NW) { + short i = ii + tiisg; + dst[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D + i] = sr[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_scalar_f16_h32")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_f16_h64")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_f16_h96")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; + +template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h32")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h64")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h96")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; +template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16; + template kernel void kernel_cpy( device const void * src0,