diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0d23c1244..bdd50e2b6 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + 2*Q); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
@@ -6648,13 +6652,15 @@
 
                 for (int64_t i = 0; i < D16; ++i) {
                     // convert accumulator to matrix_b
-                    // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
-                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
 
                     nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
                     nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
+
+                // restore zeros
+                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
             }
 
             // O = O + (Q*K^T)*V
@@ -10928,14 +10934,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int ncpw = 32; // cache values per warp (does not work for other values)
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+                              // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
     const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
-    // try to avoid this
-    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
 
     switch (Q->ne[0]) {
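
For reference, the shared-memory budget implied by the new layout can be written out explicitly. The sketch below is not part of the patch; it merely re-derives the host-side `shmem` expression `nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2)` from the kernel-side quantities `SH = C + Q` and `T = D + num_warps*SH`. The helper and parameter names (`flash_attn_shmem_bytes`, `head_dim`, ...) are illustrative and do not exist in ggml.

// Minimal sketch (not from the patch): dynamic shared memory implied by the new layout.
// Per query row the kernel keeps D halves of Q data plus, for each warp, C halves of
// attention scores and Q halves for the diag(ms) block, i.e. SH = C + Q
// (previously C + 2*Q, because the matrix_b conversion needed an extra QxQ scratch block).
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr std::size_t flash_attn_shmem_bytes(std::size_t head_dim,           // D == Q->ne[0]
                                             std::size_t queries_per_block,  // Q == nqpb
                                             std::size_t cache_per_warp,     // C == ncpw
                                             std::size_t n_warps) {
    const std::size_t SH = cache_per_warp + queries_per_block; // shared memory per warp, in half
    const std::size_t T  = head_dim + n_warps*SH;              // shared memory per query row, in half
    return queries_per_block*T*sizeof(std::uint16_t);          // sizeof(half) == 2 bytes
}

int main() {
    // e.g. D = 128, nqpb = 16, ncpw = 32, 8 warps -> 16*(128 + 8*48)*2 = 16384 bytes
    std::printf("%zu\n", flash_attn_shmem_bytes(128, 16, 32, 8));
    return 0;
}

With these example values the block stays well under the 48 KB of dynamic shared memory available per block by default, and dropping the extra `nqpb` term saves `nqpb*nwarps*nqpb` halves (4 KB here) compared to the old layout.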