cuda : fix flash_attn kernel to produce same results as CPU

commit 71b69aa7fd
parent fd878f71ed

2 changed files with 42 additions and 26 deletions

ggml-cuda.cu | 62
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW = WARP_SIZE;
-    const int SH = (C + Q);   // shared memory per simdgroup in (half)
+    const int SH = (C + 2*Q); // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
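Note on the SH change: each simdgroup now reserves an extra Q columns of shared memory, which the kernel later uses as scratch when spilling wmma accumulator fragments so they can be reloaded as regular matrix operands (see the diag(ms)*O hunk below and the matching host-side TODO). A rough check of how the per-query row grows, using illustrative sizes (head size 128 as in the tests, Q and C matching the nqpb/ncpw launch constants, nwarps assumed to be 4):

    // illustrative sizes only; not taken verbatim from the diff
    constexpr int D = 128, Q = 16, C = 32, num_warps = 4;
    constexpr int SH_old = C + Q;                 // 48 halves per simdgroup (before)
    constexpr int SH_new = C + 2*Q;               // 64 halves per simdgroup (after)
    constexpr int T_old  = D + num_warps*SH_old;  // 320 halves per query row
    constexpr int T_new  = D + num_warps*SH_new;  // 384 halves per query row
    static_assert(T_new - T_old == num_warps*Q, "Q extra halves per query row for each warp");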
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
     // pointer to the mask
     const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
+    // prepare diagonal scale matrix
+    half16x16_b mscale;
+    for (int i = 0; i < 16; ++i) {
+        ss[i*T + i] = __float2half(scale);
+    }
+    nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
     for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
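The new block builds diag(scale) once in shared memory and loads it into a matrix_b fragment, so that scaling and mask addition can be done in a single tensor-core mma (mqk*diag(scale) + mask). The old per-element loop could apply the scale but had no way to add the mask (note its "TODO: process mask"), which is what caused the mismatch with the CPU path. A standalone sketch of the same trick against the public nvcuda::wmma API (sm_70+, 16x16x16 half tiles; the kernel name and launch shape are hypothetical, not part of the commit):

    #include <mma.h>
    using namespace nvcuda;

    // y = scale * x for one 16x16 half tile, computed as X * diag(scale) on tensor cores.
    // Launch with a single warp (32 threads); x and y are row-major 16x16 tiles in global memory.
    __global__ void scale_tile_via_diag(const half * x, half * y, float scale) {
        __shared__ __align__(32) half diag[16*16];

        // write diag(scale) into shared memory
        for (int i = threadIdx.x; i < 16*16; i += 32) {
            diag[i] = __float2half(i/16 == i%16 ? scale : 0.0f);
        }
        __syncwarp();

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> d;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

        wmma::load_matrix_sync(a, x, 16);     // X
        wmma::load_matrix_sync(d, diag, 16);  // diag(scale)
        wmma::fill_fragment(acc, 0.0);        // start from zero (the kernel passes the mask here instead)
        wmma::mma_sync(acc, a, d, acc);       // acc = X * diag(scale) = scale * X
        wmma::store_matrix_sync(y, acc, 16, wmma::mem_row_major);
    }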
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(
 
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // TODO: process mask
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                    }
+                    half16x16_a mqka;
+                    half16x16_acc mm;
+
+                    // convert accumulator to matrix_a
+                    nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                    nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
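wmma has no direct cast from an accumulator fragment to a matrix_a operand, which is why the new code round-trips mqk[j] through its shared-memory tile before the mma with mscale and the mask. The pattern in isolation (a sketch; acc_to_matrix_a, scratch and ldm are hypothetical names, and scratch must point to a 16x16 half tile private to the calling warp):

    #include <mma.h>
    using namespace nvcuda;

    // Spill an accumulator tile to shared memory and reload it as a matrix_a operand.
    __device__ void acc_to_matrix_a(
            wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> & a,
            const wmma::fragment<wmma::accumulator, 16, 16, 16, half> & acc,
            half * scratch, int ldm) {
        wmma::store_matrix_sync(scratch, acc, ldm, wmma::mem_row_major); // accumulator -> shared memory
        wmma::load_matrix_sync (a, scratch, ldm);                        // shared memory -> matrix_a
    }

In this hunk the scratch is the very ss slot the result is written back to, so no extra shared memory is needed here; the diag(ms)*O hunk below is the one that needs the extra Q columns added in the first hunk.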
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(
 
             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                // half16x16_a mm;
-                // half16x16_b zro;
+                half16x16_a mm;
+                half16x16_b lob;
 
-                // nvcuda::wmma::fill_fragment(zro, 0.0);
-                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
                 for (int64_t i = 0; i < D16; ++i) {
-                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                    for (uint32_t k = 0; k < 16*16; k++) {
-                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                    }
+                    // convert accumulator to matrix_b
+                    // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
+                    nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+
+                    nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
             }
 
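The diag(ms)*O update uses the same round-trip in the matrix_b direction, and because the product must replace lo[j][i] rather than accumulate into it, the accumulator is cleared with fill_fragment before the mma. Reduced to its essentials (a sketch; the helper name and scratch pointer are hypothetical, and diag_ms is assumed to already hold the diagonal factors):

    #include <mma.h>
    using namespace nvcuda;

    // o = diag_ms * o, where o lives in an accumulator fragment.
    __device__ void scale_rows_by_diag(
            wmma::fragment<wmma::accumulator, 16, 16, 16, half> & o,
            const wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> & diag_ms,
            half * scratch, int ldm) {
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> ob;

        wmma::store_matrix_sync(scratch, o, ldm, wmma::mem_row_major); // spill O
        wmma::load_matrix_sync (ob, scratch, ldm);                     // reload O as matrix_b
        wmma::fill_fragment(o, 0.0);                                   // clear, so the mma is a pure product
        wmma::mma_sync(o, diag_ms, ob, o);                             // o = diag_ms * O
    }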
@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                    // store temporally 'lo' data
+
+                    // convert accumulator to matrix_b
                     nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                    // load 'lo' data into t
                     nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
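The stricter padding follows the 16-row wmma tiles: GGML_PAD(x, n) (from ggml.h) rounds x up to the next multiple of n, so the mask's second dimension must now cover the query count rounded up to 16 instead of 8. For example, with an illustrative query count of 33:

    #include "ggml.h"

    // 33 queries pad to 48 mask rows under the new requirement
    static_assert(GGML_PAD(33, 16) == 48, "mask->ne[1] must be at least 48 for Q->ne[1] == 33");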
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
+    // try to avoid this
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0])
     {
         case 16:
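The host-side shared-memory formula mirrors the kernel's new layout (the 2*nqpb term matches the SH = C + 2*Q change) and is now computed as size_t. A quick sanity check with assumed launch values (head size 128 and nwarps = 4, which the new expression selects whenever Q->ne[1] > nqpb):

    // 16 * (128 + 4*(32 + 2*16)) * 2 bytes = 12 KiB per block (illustrative values)
    constexpr int    nqpb = 16, ncpw = 32, nwarps = 4, head_size = 128;
    constexpr size_t shmem = nqpb*(head_size + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
    static_assert(shmem == 12288, "well under the 48 KiB default shared-memory limit per block");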
@@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (int hs : { 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, 512 }) {
+                for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
                     test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
                 }