cuda : fix flash_attn kernel to produce same results as CPU

Georgi Gerganov 2024-02-01 09:40:56 +02:00
parent fd878f71ed
commit 71b69aa7fd
2 changed files with 42 additions and 26 deletions


@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + Q);          // shared memory per simdgroup in (half)
+    const int SH  = (C + 2*Q);        // shared memory per simdgroup in (half)
     const int T   = D + num_warps*SH; // shared memory size per query in (half)
     const int T2  = T/2;              // shared memory size per query in (half2)
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
         }
     }

     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     // pointer to the mask
     const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

+    // prepare diagonal scale matrix
+    half16x16_b mscale;
+    for (int i = 0; i < 16; ++i) {
+        ss[i*T + i] = __float2half(scale);
+    }
+    nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
     for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // TODO: process mask
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                    }
+                    half16x16_a   mqka;
+                    half16x16_acc mm;
+
+                    // convert accumulator to matrix_a
+                    nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                    nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
+
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
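
Note: the old code only multiplied each accumulator element by `scale` and never applied the mask (hence the removed `// TODO: process mask`), which is one of the reasons the CUDA kernel disagreed with the CPU reference. The new code converts `mqk[j]` into a `matrix_a` fragment via shared memory, loads the corresponding 16x16 mask tile into an accumulator fragment, and computes `mqk = mqk*diag(scale) + mask` with a single `mma_sync` against the `mscale` fragment prepared earlier. A minimal host-side sketch (illustrative only, not part of the commit) of why multiplying by a diagonal scale matrix on the right is equivalent to the scalar formulation:

#include <cstdio>

int main() {
    const int   N     = 16;
    const float scale = 0.125f;

    static float A[16][16], D[16][16], M[16][16];

    for (int r = 0; r < N; ++r) {
        for (int c = 0; c < N; ++c) {
            A[r][c] = 0.01f*(r*N + c);            // stand-in for a Q*K^T tile
            D[r][c] = (r == c) ? scale : 0.0f;    // diag(scale), the role of mscale
            M[r][c] = (c > r) ? -10000.0f : 0.0f; // stand-in for a mask tile
        }
    }

    // out = A*D + M, i.e. what mma_sync(mqk[j], mqka, mscale, mm) computes per tile
    float max_err = 0.0f;
    for (int r = 0; r < N; ++r) {
        for (int c = 0; c < N; ++c) {
            float acc = M[r][c];
            for (int k = 0; k < N; ++k) {
                acc += A[r][k]*D[k][c];
            }
            // the diagonal picks out A[r][c]*scale, so the result matches scale*A + M
            const float ref = scale*A[r][c] + M[r][c];
            const float err = acc > ref ? acc - ref : ref - acc;
            if (err > max_err) max_err = err;
        }
    }

    printf("max |A*diag(scale) + M - (scale*A + M)| = %g\n", max_err); // 0
    return 0;
}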
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(
                 // O = diag(ms)*O
                 for (int64_t j = 0; j < Q16; ++j) {
-                    // half16x16_a mm;
-                    // half16x16_b zro;
-                    // nvcuda::wmma::fill_fragment(zro, 0.0);
-                    // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                    half16x16_a mm;
+                    half16x16_b lob;
+
+                    nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                     for (int64_t i = 0; i < D16; ++i) {
-                        //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                        for (uint32_t k = 0; k < 16*16; k++) {
-                            half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                            lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                        }
+                        // convert accumulator to matrix_b
+                        // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
+                        nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+
+                        nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                        nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                     }
                 }
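
Note: the replaced loop scaled `lo[j][i].x[k]` elementwise while indexing `ss` by `k%16`; since the mapping between WMMA fragment elements and matrix coordinates is unspecified, that per-element scaling was not guaranteed to pick up the intended diag(ms) entry, which is another source of the CPU/GPU mismatch. The fix spills the accumulator to shared memory, reloads it as a `matrix_b` fragment, and applies `lo = diag(ms)*lo` with a proper `mma_sync`; the extra QxQ scratch area this needs is why `SH` grew from `C + Q` to `C + 2*Q` above. A self-contained sketch of the accumulator-to-`matrix_b` round trip (hypothetical demo kernel, assumes sm_70 or newer and a single-warp launch):

#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

// demo: compute out = A*(A*B) for one 16x16 tile; the intermediate product lives in an
// accumulator fragment, so it is round-tripped through shared memory before it can be
// fed into the "B" slot of the next mma_sync
__global__ void acc_to_matrix_b_demo(const half * A, const half * B, half * out) {
    __shared__ half scratch[16*16];

    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> fb;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

    // acc = A*B
    wmma::load_matrix_sync(fa, A, 16);
    wmma::load_matrix_sync(fb, B, 16);
    wmma::fill_fragment(acc, 0.0f);
    wmma::mma_sync(acc, fa, fb, acc);

    // there is no direct accumulator -> matrix_b conversion:
    // store the accumulator and re-load it as a matrix_b fragment
    wmma::store_matrix_sync(scratch, acc, 16, wmma::mem_row_major);
    wmma::load_matrix_sync (fb, scratch, 16);

    // second product uses the converted fragment: out = A*(A*B)
    wmma::fill_fragment(acc, 0.0f);
    wmma::mma_sync(acc, fa, fb, acc);
    wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
}

// launch with a single warp, e.g. acc_to_matrix_b_demo<<<1, 32>>>(dA, dB, dOut);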
@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
                         nvcuda::wmma::fill_fragment(t2, 0.0);
                         nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                         nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                        // store temporally 'lo' data
-                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                        // load 'lo' data into t
-                        nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+                        // convert accumulator to matrix_b
+                        nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                         nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                     }
                 }
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-            "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+            "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)

-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    const int nwarps     = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
+    // try to avoid this
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);

     switch (Q->ne[0])
         {
             case 16:
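
Note: the launch configuration changes as well: instead of a hardcoded single warp, each block now uses up to `nwarps_max = 8` warps depending on the KV length, and the dynamic shared memory grows by one extra `nqpb` tile per warp for the diag(ms)*O scratch area. A worked sizing example (values assumed for illustration: 128-dim heads, KV long enough that `nwarps` reaches the maximum):

#include <cstdio>

int main() {
    const int nqpb   = 16;  // queries per block
    const int ncpw   = 32;  // cache values per warp
    const int nwarps = 8;   // nwarps_max, reached for long KV caches
    const int D      = 128; // head size, Q->ne[0]

    // same formula as above; sizeof(float)/2 == sizeof(half) bytes per element
    const size_t shmem = (size_t) nqpb*(D + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
    printf("shmem per block: %zu bytes\n", shmem); // 20480 bytes, well under the default 48 KB per-block limit
    return 0;
}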


@@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (int hs : { 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, 512 }) {
+                for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
                     test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
                 }
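
Note: the interesting new batch sizes are 7 and 15: they are not multiples of 8 or 16, so they exercise partially filled 16-row query tiles and the tightened `GGML_PAD(Q->ne[1], 16)` mask requirement above. A small sketch of what that padding requirement means for the added sizes (assumes the usual round-up definition of `GGML_PAD`):

// round x up to a multiple of n (the usual GGML_PAD bit trick)
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

static_assert(GGML_PAD(  7, 16) ==  16, "nb=7 needs a mask with at least 16 rows");
static_assert(GGML_PAD( 15, 16) ==  16, "nb=15 needs a mask with at least 16 rows");
static_assert(GGML_PAD( 16, 16) ==  16, "nb=16 is already aligned");
static_assert(GGML_PAD(512, 16) == 512, "nb=512 is already aligned");

int main() { return 0; }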