From 2335086fd3687e450f850855d62a0b802b436d36 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 6 Nov 2024 22:04:07 +0200
Subject: [PATCH] wip2

---
 ggml/src/ggml-metal.metal | 80 +++++++++++++++++++++++++--------------
 1 file changed, 52 insertions(+), 28 deletions(-)

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 5f63d7e04..80e5dcc4e 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ? 0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
@@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + md;
-        reg[i/2][2*(i%2)+1] = d * x1 + md;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ? 0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
@@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + m;
-        reg[i/2][2*(i%2)+1] = d * x1 + m;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 16; i++) {
-        reg[i/4][i%4] = (qs[i + 16*il] * d);
+        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -2768,6 +2788,7 @@ template<
     typename s_t, // attention accumulation types
     typename s8x8_t,
     typename o_t,
+    typename o4_t,
     typename o8x8_t,
     typename block_q,
     short nl_k,
@@ -2835,6 +2856,7 @@ kernel void kernel_flash_attn_ext(
     threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
     threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
     threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
     threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@@ -2854,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
             if (iq1 + j < ne01) {
                 sq4[j*T4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*T4 + i] = 0.0h;
+                sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
             }
         }
     }
@@ -2867,7 +2889,7 @@ kernel void kernel_flash_attn_ext(
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TS + i] = 0.0f;
+            ss[j*TS + i] = (s_t) 0.0f;
         }
     }
 
@@ -3024,12 +3046,12 @@ kernel void kernel_flash_attn_ext(
                 S[j] = S[j]*ms[j] + simd_sum(vs);
 
                 // the P matrix from the paper (Q rows, C columns)
-                ss[j*TS + tiisg] = vs;
+                ss[j*TS + tiisg] = (s_t) vs;
             }
 
             // create a QxQ diagonal matrix for rescaling the output
             if (tiisg < Q) {
-                ss[tiisg*TS + C + tiisg] = ms[tiisg];
+                ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
             }
         }
 
@@ -3114,8 +3136,8 @@ kernel void kernel_flash_attn_ext(
        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        for (short j = 0; j < Q; ++j) {
            if (tiisg == 0) {
-                ss[j*TS + 0] = S[j];
-                ss[j*TS + 1] = M[j];
+                ss[j*TS + 0] = (s_t) S[j];
+                ss[j*TS + 1] = (s_t) M[j];
            }
        }
    }
@@ -3139,11 +3161,11 @@ kernel void kernel_flash_attn_ext(
        // the first simdgroup accumulates the results from the other simdgroups
        if (sgitg == 0) {
            for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TS + 0];
-                const float S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = (s_t) ss[j*TS + 0];
+                const float S1 = (s_t) ss[j*TS + sg*SH + 0];
 
-                const float M0 = ss[j*TS + 1];
-                const float M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = (s_t) ss[j*TS + 1];
+                const float M1 = (s_t) ss[j*TS + sg*SH + 1];
 
                M = max(M0, M1);
 
@@ -3153,11 +3175,11 @@ kernel void kernel_flash_attn_ext(
                S = S0*ms0 + S1*ms1;
 
                if (tiisg == 0) {
-                    ss[j*TS + 0] = S;
-                    ss[j*TS + 1] = M;
+                    ss[j*TS + 0] = (s_t) S;
+                    ss[j*TS + 1] = (s_t) M;
 
-                    ss[j*TS + C + j ] = ms0;
-                    ss[j*TS + C + j + sg*SH] = ms1;
+                    ss[j*TS + C + j ] = (s_t) ms0;
+                    ss[j*TS + C + j + sg*SH] = (s_t) ms1;
                }
            }
 
@@ -3196,7 +3218,7 @@ kernel void kernel_flash_attn_ext(
            const float S = ss[j*TS + 0];
 
            for (short i = tiisg; i < D4; i += NW) {
-                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
            }
        }
    }
@@ -3213,7 +3235,7 @@ kernel void kernel_flash_attn_ext(
     half, half4x4, simdgroup_half8x8, \
     half, half4x4, simdgroup_half8x8, \
     half, simdgroup_half8x8, \
-    half, simdgroup_half8x8
+    half, half4, simdgroup_half8x8
 #else
 #define S_T float
 #define S4_T float4
@@ -3225,9 +3247,11 @@ kernel void kernel_flash_attn_ext(
     half, half4x4, simdgroup_half8x8, \
     half, half4x4, simdgroup_half8x8, \
     float, simdgroup_float8x8, \
-    half, simdgroup_half8x8
+    half, half4, simdgroup_half8x8
 #endif
 
+// TOOD: static_assert
+
 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
 
 template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;