From 71b69aa7fd0aee89c4750d230bee7a4601d8fc1f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 1 Feb 2024 09:40:56 +0200
Subject: [PATCH] cuda : fix flash_attn kernel to produce same results as CPU

---
 ggml-cuda.cu               | 66 +++++++++++++++++++++++---------------
 tests/test-backend-ops.cpp |  2 +-
 2 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 86afb0133..0d23c1244 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW = WARP_SIZE;
-    const int SH = (C + Q); // shared memory per simdgroup in (half)
+    const int SH = (C + 2*Q); // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
     // pointer to the mask
     const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
+    // prepare diagonal scale matrix
+    half16x16_b mscale;
+    for (int i = 0; i < 16; ++i) {
+        ss[i*T + i] = __float2half(scale);
+    }
+    nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
     for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // TODO: process mask
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                    }
+                    half16x16_a   mqka;
+                    half16x16_acc mm;
+
+                    // convert accumulator to matrix_a
+                    nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                    nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
 
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(
 
                 // O = diag(ms)*O
                 for (int64_t j = 0; j < Q16; ++j) {
-                    // half16x16_a mm;
-                    // half16x16_b zro;
+                    half16x16_a mm;
+                    half16x16_b lob;
 
-                    // nvcuda::wmma::fill_fragment(zro, 0.0);
-                    // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                    nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
                     for (int64_t i = 0; i < D16; ++i) {
-                        //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                        for (uint32_t k = 0; k < 16*16; k++) {
-                            half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                            lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                        }
+                        // convert accumulator to matrix_b
+                        // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
+                        nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+
+                        nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                        nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                     }
                 }
 
@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::fill_fragment(t2, 0.0);
                 nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                 nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                // store temporally 'lo' data
-                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                // load 'lo' data into t
-                nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                // convert accumulator to matrix_b
+                nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                 nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
             }
         }
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-            "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+            "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
 
-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too many warps. how many is too many?
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
+    // try to avoid this
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0]) {
         case 16:
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b1b30b91c..e632142a7 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (int hs : { 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, 512 }) {
+                for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
                     test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
                 }
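
Note on the new shared-memory size: the launch-time request mirrors the kernel constants above, i.e. per query row D halves for the query data plus nwarps*(C + 2*Q) halves of per-warp scratch (SH), times nqpb query rows, times 2 bytes per half. The stand-alone helper below is only an illustrative sketch of that arithmetic under those assumptions; the function and program are not part of the patch:

    // illustrative only: recompute the kernel's per-block shared-memory requirement on the host
    // D = head size (Q->ne[0]), Q = queries per block (nqpb),
    // C = cache values per warp (ncpw), nwarps = warps per block
    #include <cstdio>
    #include <cstddef>
    #include <initializer_list>

    static size_t flash_attn_shmem_bytes(int D, int Q, int C, int nwarps) {
        // per query row: D halves for the query data + nwarps*(C + 2*Q) halves of scratch,
        // 2 bytes per half -- same value as nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2)
        return (size_t) Q*(D + nwarps*(C + 2*Q))*2;
    }

    int main() {
        for (int nwarps : {4, 8}) {
            printf("D=128 Q=16 C=32 nwarps=%d -> %zu bytes of shared memory\n",
                   nwarps, flash_attn_shmem_bytes(128, 16, 32, nwarps));
        }
        return 0; // prints 12288 and 20480 bytes
    }

For hs = 128 with nwarps capped at 8 this stays comfortably below the default 48 KB per-block shared-memory limit, which suggests the nwarps_max TODO is about occupancy rather than shared-memory capacity.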