From f6416d449362c350d4211525d7782675c6d244fd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 12:59:59 +0200 Subject: [PATCH] wip : good version 8x32 --- ggml-metal.m | 4 +-- ggml-metal.metal | 77 +++++++++++++++++++++--------------------------- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 93b499a12..a3191e35a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 8; + const int64_t ncpsg = 32; //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d4719ea0..9c5d1ed2e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2072,9 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } } - if (tiisg < 1) { + if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; + ss[j*T + 0 + tiisg] = 0.0h; } } @@ -2128,36 +2128,26 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - //{ - // bool skip = true; - // for (int64_t j = 0; j < Q; ++j) { - // skip = skip && (mp[j][iic] == -INFINITY); - // } - // if (skip) { - // continue; - // } - //} - { simdgroup_half8x8 mk; - simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + (iic*nb11 + ik2*nb12 + ik3*nb13)); + for (int cc = 0; cc < 4; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } - - simdgroup_store(mqk, ss, T, 0, false); } - // not sure why this barrier is needed - simdgroup_barrier(mem_flags::mem_none); - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg % C; + const int64_t p = tiisg; const half s = ss[j*T + p]*scale + (mp[j][iic + p]); @@ -2168,37 +2158,38 @@ kernel void kernel_flash_attn_ext_f16( const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32 + S[j] = S[j]*ms + simd_sum(vs); for (int64_t i = 0; i < L4; ++i) { ls4[j][i] *= ms; } - if (tiisg < C) { - ss[j*T + p] = vs; - } + ss[j*T + p] = vs; } { simdgroup_half8x8 mv; - simdgroup_half8x8 mp; - simdgroup_half8x8 mqkv; - - device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23)); - - // load mp - simdgroup_load(mp, ss, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (mv, pv + i*8, nb21/2, 0, false); - simdgroup_multiply(mqkv, mp, mv); - simdgroup_store (mqkv, ps + i*8, T, 0, false); + simdgroup_half8x8 mp[4]; + simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix(0.h); + + for (int cc = 0; cc < 4; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < 4; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); } } - // not sure why this barrier is needed too - threadgroup_barrier(mem_flags::mem_none); - for (int64_t j = 0; j < Q; ++j) { for (int64_t i = 0; i < L4; ++i) { ls4[j][i] += ps4[j*T4 + N4*i + tiisg]; @@ -2284,9 +2275,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 8>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 8>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 8>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0,