cuda : avoid extra QxQ matrix in shared memory

commit  2c04beeb81
parent  71b69aa7fd
Author: Georgi Gerganov
Date:   2024-02-01 14:03:03 +02:00

@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + 2*Q); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q);   // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
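
Note: judging from the offsets used in the hunks below, the per-warp scratch ss is laid out as C half columns for the S = Q*K^T attention scores followed by Q columns for the QxQ diagonal (ms/vs) matrix; the old layout reserved a further QxQ block at offset C + Q purely as a staging area for converting the output accumulator into a matrix_b fragment. This commit drops that third block, shrinking SH from C + 2*Q to C + Q halves per warp, and reuses the 16x16 diagonal block of the existing scratch instead (see the hunk at 6648 below and the sketch after it); the host-side shared-memory formula in the hunk at 10928 changes accordingly.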
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    half16x16_acc zr;
+
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
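
zr is an accumulator fragment filled with zeros once, before the main loop over the KV cache, so that restoring the clobbered diagonal block later (see the "restore zeros" line below) costs a single warp-wide store_matrix_sync per row block rather than per-element writes.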
@@ -6648,13 +6652,15 @@ static __global__ void flash_attn_ext_f16(
            for (int64_t i = 0; i < D16; ++i) {
                // convert accumulator to matrix_b
-               // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
-               nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
-               nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+               nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+               nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
 
                nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
                nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
            }
 
+           // restore zeros
+           nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
        }
 
        // O = O + (Q*K^T)*V
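
This hunk is the core of the commit. WMMA accumulator fragments cannot be passed to mma_sync as the matrix_b operand, so the diag(ms)*O update has to bounce lo[j][i] through shared memory and reload it as a matrix_b fragment (lob). Previously that bounce went through a dedicated QxQ staging block at offset C + Q; now it reuses the j-th 16x16 diagonal block of the existing scratch at offset C + 16*j, whose contents are presumably already held in the mm fragment at that point, and the zero accumulator zr is stored back afterwards so the off-diagonal entries of that block are zero again for the next iteration. Below is a minimal stand-alone sketch of the accumulator-to-matrix_b round trip; it is not taken from the commit, all names in it are made up, and it assumes an sm_70+ GPU.

    // Stand-alone sketch: bounce a WMMA accumulator through shared memory so it
    // can be re-used as a matrix_b operand, the same conversion the kernel does
    // with lo[j][i]/lob. Compile with e.g. `nvcc -arch=sm_70 demo.cu`.
    #include <cuda_fp16.h>
    #include <mma.h>
    #include <cstdio>

    using namespace nvcuda;

    __global__ void acc_as_b_demo(const half * A, const half * B, half * out) {
        // plays the role of the reused 16x16 block of ss; wmma needs 256-bit alignment
        __shared__ __align__(32) half scratch[16*16];

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> ab_as_b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;

        wmma::load_matrix_sync(a, A, 16);
        wmma::load_matrix_sync(b, B, 16);

        wmma::fill_fragment(acc, 0.0f);
        wmma::mma_sync(acc, a, b, acc);       // acc = A*B

        // an accumulator cannot be fed back into mma_sync as matrix_b directly:
        // store it to (shared) memory and reload it as a matrix_b fragment
        wmma::store_matrix_sync(scratch, acc, 16, wmma::mem_row_major);
        wmma::load_matrix_sync(ab_as_b, scratch, 16);

        wmma::fill_fragment(acc, 0.0f);
        wmma::mma_sync(acc, a, ab_as_b, acc); // acc = A*(A*B)

        wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
    }

    int main() {
        half *A, *B, *out;
        cudaMallocManaged(&A,   16*16*sizeof(half));
        cudaMallocManaged(&B,   16*16*sizeof(half));
        cudaMallocManaged(&out, 16*16*sizeof(half));
        for (int i = 0; i < 16*16; ++i) {
            A[i] = __float2half(i/16 == i%16 ? 1.0f : 0.0f); // A = identity
            B[i] = __float2half(0.01f*(i%16));
        }
        acc_as_b_demo<<<1, 32>>>(A, B, out);  // one warp handles one 16x16 tile
        cudaDeviceSynchronize();
        printf("out[17] = %.2f\n", __half2float(out[17])); // A*(A*B) = B -> 0.01
        return 0;
    }

The kernel does exactly this round trip, only with the reused diagonal block of ss as the staging memory instead of a dedicated buffer, which is why the extra QxQ region can be dropped.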
@@ -10928,14 +10934,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int ncpw = 32; // cache values per warp (does not work for other values)
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
     const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
-    // try to avoid this
-    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
 
     switch (Q->ne[0])
     {
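
For a rough sense of the saving on the host side: sizeof(float)/2 is just sizeof(half), so the dynamic shared memory is nqpb*(head_size + nwarps*(ncpw + nqpb)) half elements. Assuming, for illustration, nqpb = 16 (its value is not shown in this hunk), head size Q->ne[0] = 128 and nwarps = 4, the old formula with 2*nqpb gives 16*(128 + 4*(32 + 32))*2 = 12288 bytes per block, while the new one gives 16*(128 + 4*(32 + 16))*2 = 10240 bytes, i.e. nqpb*nwarps*nqpb*sizeof(half) = 2048 bytes saved, which matches dropping one QxQ staging block per warp.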