fix excessive KQ_b loads
This commit is contained in:
parent
e1ecd3b129
commit
bb0d51accd
1 changed files with 8 additions and 4 deletions
|
@ -387,12 +387,16 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
|
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||||
nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||||
|
nvcuda::wmma::load_matrix_sync(
|
||||||
|
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||||
|
KQ + j0*kqs_padded + k,
|
||||||
|
kqs_padded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,7 +416,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue