fix excessive KQ_b loads

This commit is contained in:
Johannes Gäßler 2024-04-02 11:13:46 +02:00 committed by Georgi Gerganov
parent e1ecd3b129
commit bb0d51accd

View file

@@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
-        frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
-                nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*kqs_padded + k,
+                    kqs_padded);
             }
         }
@@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }