This commit is contained in:
Georgi Gerganov 2024-02-04 12:34:36 +02:00
parent 1846e92a90
commit 49a483e0f2
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6565,7 +6565,7 @@ static __global__ void flash_attn_ext_f16(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
const int ic = ic0 + warp_id*C;
const int ic = ic0 + warp_id*16;
if (ic >= ne11) {
break;
}
@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(mqk[j], 0);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
const half * pk = (const half *) ((const char *) k + ((ic + 16*num_warps*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
@ -6596,7 +6596,7 @@ static __global__ void flash_attn_ext_f16(
half16x16_acc mm;
if (mp) {
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*num_warps*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
}
// convert accumulator to matrix_a
@ -6686,22 +6686,25 @@ static __global__ void flash_attn_ext_f16(
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16];
half16x16_b mv[C16][D16];
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*num_warps*cc)*256 + iv2*nb22 + iv3*nb23));
nvcuda::wmma::load_matrix_sync(mv[cc][i], pv + i*16, 256/sizeof(half));
}
}
for (int cc = 0; cc < C16; ++cc) {
half16x16_a ms[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
}
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[cc][i], lo[j][i]);
}
}
}