diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index d1d018dc7..dcd2129a2 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n]; + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) { - nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded); + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*kqs_padded + k, + kqs_padded); } } @@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } }