From 3b9625032ce4dc0b2c78c5d31910938ecc769395 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 Nov 2024 20:34:16 +0200 Subject: [PATCH] f16 vec --- ggml/src/ggml-metal.m | 6 +- ggml/src/ggml-metal.metal | 152 +++++++++++++++++++------------------- 2 files changed, 77 insertions(+), 81 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index aecd6bc02..25f4d7b82 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3294,10 +3294,10 @@ static void ggml_metal_encode_node( // for each query, we load it as f16 in shared memory (ne00) // and store the attention scores (nqptg x ncpsg) as f32 // - // 2*ne00*(nsg) - // each simdgroup has a full f32 head vector in shared mem to accumulate results + // ne00*(nsg) + // each simdgroup has a full f16 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 8eb3faa86..20b104611 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3219,7 +3219,7 @@ kernel void kernel_flash_attn_ext( // final rescale with 1/S and store to global memory if (sgitg == 0) { for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*TS + 0]; + const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; @@ -3292,19 +3292,21 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_ #undef FA_TYPES template< - typename q4_t, + typename q4_t, // query types in shared memory typename q4x4_t, - typename k4x4_t, - typename v4x4_t, - typename s_t, // attention accumulation types + typename k4x4_t, // key types in shared memory + typename v4x4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types typename s4_t, typename s4x4_t, - typename o4x4_t, - typename block_q, + typename o4x4_t, // attention accumulation types + typename kd4x4_t, // key type in device memory short nl_k, - void (*deq_k)(device const block_q *, short, thread k4x4_t &), + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory short nl_v, - void (*deq_v)(device const block_q *, short, thread v4x4_t &), + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short D, // head size short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup @@ -3333,14 +3335,14 @@ kernel void kernel_flash_attn_ext_vec( constant float & max_bias, constant float & m0, constant float & m1, - constant uint32_t & n_head_log2, + constant uint16_t & n_head_log2, constant float & logit_softcap, threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + ushort3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -3353,16 +3355,14 @@ kernel void kernel_flash_attn_ext_vec( const short NW4 = NW/4; const short SH = C; // shared memory per simdgroup in (half) - const short SF = sizeof(s_t)/sizeof(half); + const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short T = D + SF*nsg*SH; // shared memory size per query in (half) - - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4 - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4 - threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + SF*sgitg*SH + 1*D); // same as above but in half4 - threadgroup s4x4_t * sr4x4 = (threadgroup s4x4_t *) (shared + SF*sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4 + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4 + threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in half4 + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NW4]; @@ -3374,7 +3374,7 @@ kernel void kernel_flash_attn_ext_vec( if (iq1 < ne01) { sq4[i] = (q4_t) q4[i]; } else { - sq4[i] = (q4_t) (float4) 0.0f; + sq4[i] = (q4_t) 0.0f; } } @@ -3385,14 +3385,14 @@ kernel void kernel_flash_attn_ext_vec( // zero out shared memory SH for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = (s4_t) (float4) 0.0f; + ss4[i] = (s4_t) 0.0f; } threadgroup_barrier(mem_flags::mem_threadgroup); { - float S = 0.0f; - float M = -FLT_MAX/2; + half S = 0.0f; + half M = -__FLT16_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%8; @@ -3406,25 +3406,25 @@ kernel void kernel_flash_attn_ext_vec( const short ikv3 = iq3/(ne03/ne_12_3); // load the queries from shared memory into local memory - k4x4_t mq[D16/NW4]; + q4x4_t mq[D16/NW4]; for (short ii = 0; ii < D16; ii += NW4) { - mq[ii/NW4] = (k4x4_t) sq4x4[ii + tx]; + mq[ii/NW4] = sq4x4[ii + tx]; } // pointer to the mask device const half * mp = (device const half *) (mask + iq1*nb31); - float slope = 1.0f; + half slope = 1.0f; // ALiBi if (max_bias > 0.0f) { - const uint32_t h = iq2; + const short h = iq2; - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < n_head_log2 ? m0 : m1; + const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - slope = pow(base, exp); + slope = pow(base, exph); } // loop over the KV cache @@ -3439,9 +3439,9 @@ kernel void kernel_flash_attn_ext_vec( { // each simdgroup processes 1 query and 4 keys for (short cc = 0; cc < C/4; ++cc) { - s_t mqk = 0.0; + qk_t mqk = 0.0; - device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); #pragma unroll for (short ii = 0; ii < D16; ii += NW4) { @@ -3487,20 +3487,18 @@ kernel void kernel_flash_attn_ext_vec( // online softmax { - const short p = tiisg; - - const float m = M; - const float s = ss[p]; + const half m = M; + const half s = ss[tiisg]; M = simd_max(max(M, s)); - const float ms = exp(m - M); - const float vs = exp(s - M); + const half ms = exp(m - M); + const half vs = exp(s - M); S = S*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; + ss[tiisg] = vs; // O = diag(ms)*O #pragma unroll @@ -3515,9 +3513,9 @@ kernel void kernel_flash_attn_ext_vec( { #pragma unroll for (short cc = 0; cc < C/4; ++cc) { - device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); - const s4x4_t ms(ss[4*cc + ty]); + const v4x4_t ms(ss[4*cc + ty]); #pragma unroll for (short ii = 0; ii < D16; ii += NW4) { @@ -3526,7 +3524,7 @@ kernel void kernel_flash_attn_ext_vec( v4x4_t mv; deq_v(pv4 + i/nl_v, i%nl_v, mv); - lo[ii/NW4] += mv*ms; + lo[ii/NW4] += (o4x4_t)(mv*ms); } } } @@ -3572,22 +3570,22 @@ kernel void kernel_flash_attn_ext_vec( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; - const float M = max(M0, M1); + const half M = max(M0, M1); - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - const float S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[0] = (s_t) S; - ss[1] = (s_t) M; + ss[0] = S; + ss[1] = M; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 @@ -3611,33 +3609,31 @@ kernel void kernel_flash_attn_ext_vec( } } -// NOTE: can use half instead of float precision for some extra perf -// however, by default use F32 since the op should be mostly memory bandwidth bound - #define FA_TYPES \ half4, half4x4, \ - float4x4, \ - float4x4, \ + half4x4, \ + half4x4, \ + float, \ float, float4, float4x4, \ - float4x4 + half4x4 -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES