From 49a483e0f27da71d58948727711ea619aed56734 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 12:34:36 +0200 Subject: [PATCH] wip --- ggml-cuda.cu | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 713a6a89a..a2ae418cf 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6565,7 +6565,7 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { - const int ic = ic0 + warp_id*C; + const int ic = ic0 + warp_id*16; if (ic >= ne11) { break; } @@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(mqk[j], 0); } - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + const half * pk = (const half *) ((const char *) k + ((ic + 16*num_warps*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key @@ -6596,7 +6596,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc mm; if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*num_warps*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } // convert accumulator to matrix_a @@ -6686,22 +6686,25 @@ static __global__ void flash_attn_ext_f16( // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + half16x16_b mv[C16][D16]; + for (int i = 0; i < D16; ++i) { + for (int cc = 0; cc < C16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*num_warps*cc)*256 + iv2*nb22 + iv3*nb23)); - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[cc][i], pv + i*16, 256/sizeof(half)); } + } + for (int cc = 0; cc < C16; ++cc) { half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { +#pragma unroll for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[cc][i], lo[j][i]); } } }