diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b7ebfcc57..1b11a34bb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6113,7 +6113,7 @@ static __global__ void flash_attn_f32( } // based on metal version -template // D head size, R rows per block +template // D head size, Q queries per block, C cache items per blocks static __global__ void flash_attn_ext_f16( const char* __restrict__ q, const char* __restrict__ k, @@ -6141,62 +6141,64 @@ static __global__ void flash_attn_ext_f16( int ne1, int ne2, int ne3) { - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; - const int nwraps = blockDim.y; // number of warps - const int tph = WARP_SIZE / R; // threads per head + const int n_warps = blockDim.y; // number of warps const int iq3 = blockIdx.z; - const int iq2 = blockIdx.y * R + lane_id / tph; - const int iq1 = blockIdx.x; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; - if(iq2 >= ne02) { - return; - } + const int D2 = D/2; + const int N4 = WARP_SIZE; + const int L2 = (D2 + N4 - 1)/N4; + const int D8 = D/8; - // broadcast - const int rk2 = ne02 / ne12; - const int rk3 = ne03 / ne13; - // assume the same K and V shape - // const int rv2 = ne02 / ne12; - // const int rv3 = ne03 / ne13; - - // kv indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; - // const int iv2 = iq2 / rv2; - // const int iv3 = iq3 / rv3; + const int T = D + n_warps*(D + 1*C); // shared memory size per query in half + const int T2 = T/2; // shared memory size per query in half2 const half2 scale_h = __half2half2(__float2half(scale)); - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; - extern __shared__ char data_flash_attn_shmem[]; - half2* pq2 = (half2*)data_flash_attn_shmem; - half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D); - half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D); + half * pq = (half *) (data_flash_attn_shmem + 0*D); + half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); + half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); + half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); + half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); - const int tiih = lane_id % tph; // thread index in head - const int hiiw = lane_id / tph; // head index in warp - - const int D2 = D / 2; // number of half2 to store head_dim row - - // load R heads from Q to shared memory - for (int i = 0; i < D2/tph; ++i) { - if (warp_id == 0) { - pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + for (int i = 0; i < L2; ++i) { + // load heads from Q to shared memory + for (int j = warp_id; j < Q; j += n_warps) { + if (iq1 + j < ne01) { + pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id]; + } else { + pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); + } } - ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0); + // zero out shared memory + for (int j = 0; j < Q; ++j) { + ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); + } } + + if (lane_id < C) { + for (int j = 0; j < Q; ++j) { + ss[j*T + 0 + lane_id] = 0.0; + } + } + __syncthreads(); - half2 S = make_half2(0.0, 0.0); + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + half S[8] = { 0.0 }; +#if 0 half2 M = make_half2(-INFINITY, -INFINITY); + const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; + for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); if (__hisinf(mv.x) == -1) { // mv == -INFINITY @@ -6302,6 +6304,7 @@ static __global__ void flash_attn_ext_f16( dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); } } +#endif } @@ -10296,18 +10299,19 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nwarps = 32; - const int nhpw = 2; // heads per warp + const int nwarps = Q->ne[1] < 4 ? 4 : 2; + const int nqpb = 2; // queries per block + const int ncpw = 32; // cache values per warp (does not work for other values) - dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]); - dim3 block_dim(32 * nwarps, 1, 1); + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); - int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2); + int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2); printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); switch (Q->ne[0]) { case 64: - flash_attn_ext_f16<64, 2> + flash_attn_ext_f16<64, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10324,7 +10328,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 80: - flash_attn_ext_f16<80, 2> + flash_attn_ext_f16<80, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10341,7 +10345,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 128: - flash_attn_ext_f16<128, 2> + flash_attn_ext_f16<128, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key