From a647257b471067c410bb6a690487b02ae7e79dfa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 11:08:47 +0200 Subject: [PATCH] cuda : express strides with helper constants --- ggml-cuda.cu | 120 +++++++++++++++++++++++++++------------------------ 1 file changed, 63 insertions(+), 57 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 713a6a89a..d672dba2c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6399,10 +6399,16 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; + +// queries, dims, cache per fragment +#define QPF 16 +#define DPF 16 +#define CPF 16 + +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; #endif // based on metal version @@ -6443,9 +6449,9 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D16 = D/16; - const int Q16 = Q/16; - const int C16 = C/16; + const int DPT = D/DPF; // dims per thread + const int QPT = Q/QPF; // queries per thread + const int CPT = C/CPF; // cache per thread const int NW = WARP_SIZE; const int SH = (C + Q); // shared memory per simdgroup in (half) @@ -6463,7 +6469,7 @@ static __global__ void flash_attn_ext_f16( half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; - half16x16_acc lo[Q16][D16]; + half16x16_acc lo[QPT][DPT]; // load heads from Q to shared memory #pragma unroll @@ -6493,8 +6499,8 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { + for (int j = 0; j < QPT; ++j) { + for (int i = 0; i < DPT; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } @@ -6545,10 +6551,10 @@ static __global__ void flash_attn_ext_f16( const int iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + half16x16_a mq[QPT][DPT]; + for (int j = 0; j < QPT; ++j) { + for (int i = 0; i < DPT; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + QPF*j*T + DPF*i, T); } } @@ -6557,7 +6563,7 @@ static __global__ void flash_attn_ext_f16( // prepare diagonal scale matrix half16x16_b mscale; - for (int i = 0; i < 16; ++i) { + for (int i = 0; i < QPF; ++i) { ss[i*T + i] = __float2half(scale); } nvcuda::wmma::load_matrix_sync(mscale, ss, T); @@ -6573,38 +6579,38 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { #pragma unroll - for (int cc = 0; cc < C16; ++cc) { - half16x16_acc mqk[Q16]; - for (int j = 0; j < Q16; ++j) { + for (int cc = 0; cc < CPT; ++cc) { + half16x16_acc mqk[QPT]; + for (int j = 0; j < QPT; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + const half * pk = (const half *) ((const char *) k + ((ic + CPF*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int i = 0; i < D16; ++i) { + for (int i = 0; i < DPT; ++i) { half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mk, pk + DPF*i, nb11/sizeof(half)); - for (int j = 0; j < Q16; ++j) { + for (int j = 0; j < QPT; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int j = 0; j < Q16; ++j) { + for (int j = 0; j < QPT; ++j) { half16x16_a mqka; half16x16_acc mm; if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync(mm, mp + QPF*j*(nb31/sizeof(half)) + ic + CPF*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } // convert accumulator to matrix_a - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + QPF*j*T + CPF*cc, T); nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major); } } } @@ -6664,43 +6670,43 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int j = 0; j < Q16; ++j) { + for (int j = 0; j < QPT; ++j) { half16x16_a mm; half16x16_b lob; - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(mm, ss + QPF*j*T + C + QPF*j, T); - for (int i = 0; i < D16; ++i) { + for (int i = 0; i < DPT; ++i) { // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + C + QPF*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + QPF*j*T + C + QPF*j, T); nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } } // restore zeros - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + for (int j = 0; j < QPT; ++j) { + nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + C + QPF*j, zr, T, nvcuda::wmma::mem_row_major); } // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int cc = 0; cc < CPT; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + CPF*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); + half16x16_b mv[DPT]; + for (int i = 0; i < DPT; ++i) { + nvcuda::wmma::load_matrix_sync(mv[i], pv + DPF*i, nb21/sizeof(half)); } - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); + half16x16_a ms[QPT]; + for (int j = 0; j < QPT; ++j) { + nvcuda::wmma::load_matrix_sync(ms[j], ss + QPF*j*T + CPF*cc, T); } - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { + for (int j = 0; j < QPT; ++j) { + for (int i = 0; i < DPT; ++i) { nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } @@ -6721,9 +6727,9 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + for (int j = 0; j < QPT; ++j) { + for (int i = 0; i < DPT; ++i) { + nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major); } } } @@ -6754,22 +6760,22 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int j = 0; j < Q16; ++j) { + for (int j = 0; j < QPT; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; half16x16_acc t2; - nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + nvcuda::wmma::load_matrix_sync(ms0, ss + QPF*j*T + C + QPF*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + QPF*j*T + C + QPF*j + sg*SH, T); - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + for (int i = 0; i < DPT; ++i) { + nvcuda::wmma::load_matrix_sync(t, sq + QPF*j*T + DPF*i, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + nvcuda::wmma::store_matrix_sync( sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + QPF*j*T + DPF*i, T); nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); } @@ -6779,9 +6785,9 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + for (int j = 0; j < QPT; ++j) { + for (int i = 0; i < DPT; ++i) { + nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major); } } }