diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 1b11a34bb..5cb065606 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -104,6 +104,7 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <mma.h>
 
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@@ -621,6 +622,14 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
+static __device__ __forceinline__ __half warp_reduce_sum(__half x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -642,6 +651,19 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
+static __device__ __forceinline__ half warp_reduce_max(half x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    (void) x;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -6112,6 +6134,10 @@ static __global__ void flash_attn_f32(
     }
 }
 
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;
+
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
@@ -6152,17 +6178,17 @@ static __global__ void flash_attn_ext_f16(
     const int D2 = D/2;
     const int N4 = WARP_SIZE;
     const int L2 = (D2 + N4 - 1)/N4;
-    const int D8 = D/8;
+    const int D16 = D/16;
 
     const int T  = D + n_warps*(D + 1*C); // shared memory size per query in half
-    const int T2 = T/2; // shared memory size per query in half2
+    const int T2 = T/2;                   // shared memory size per query in half2
 
-    const half2 scale_h = __half2half2(__float2half(scale));
+    const half scale_h = __float2half(scale);
 
     extern __shared__ char data_flash_attn_shmem[];
-
-    half * pq = (half *) (data_flash_attn_shmem + 0*D);
-    half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
+    // pq
+    half  * pq  = (half  *) (data_flash_attn_shmem + 0*D);
+    half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
     half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
     half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
     half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
@@ -6191,120 +6217,185 @@ static __global__ void flash_attn_ext_f16(
 
     __syncthreads();
 
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+    {
+        half S[Q] = { 0.0 };
+        half M[Q] = { -INFINITY };
 
-    half S[8] = { 0.0 };
-#if 0
-    half2 M = make_half2(-INFINITY, -INFINITY);
+        // assume K and V are same shape
+        const int ne22 = ne12;
+        const int ne23 = ne13;
 
-    const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
+        const int nb21 = nb11;
+        const int nb22 = nb12;
+        const int nb23 = nb13;
 
-    for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
-        const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
-        if (__hisinf(mv.x) == -1) { // mv == -INFINITY
-            continue;
+        // broadcast
+        const int rk2 = ne02/ne12;
+        const int rk3 = ne03/ne13;
+
+        const int rv2 = ne02/ne22;
+        const int rv3 = ne03/ne23;
+
+        // k indices
+        const int ik2 = iq2 / rk2;
+        const int ik3 = iq3 / rk3;
+
+        // v indices
+        const int iv2 = iq2 / rv2;
+        const int iv3 = iq3 / rv3;
+
+        // TODO: this can be improved
+        float * mp[Q];
+
+        {
+            const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+
+            for (int j = 0; j < Q; ++j) {
+                if (iq1 + j < ne01) {
+                    mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31);
+                } else {
+                    mp[j] = nullptr;
+                }
+            }
         }
 
-        half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
-        half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
+        for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) {
+            // skip -INF blocks
+            // TODO: double-check this
+            {
+                float smc = -INFINITY;
 
-        half2 s2 = make_half2(0.0, 0.0);
+                for (int j = 0; j < Q; ++j) {
+                    const float mc = mp[j] ? mp[j][iic + lane_id] : -INFINITY;
+                    smc = warp_reduce_max(max(smc, mc));
+                }
 
-#pragma unroll
-        for (int i = 0; i < D2/tph; ++i) {
-            s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2;
-        }
-
-        ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y);
-
-        __syncthreads();
-
-        if (tiih == 0) {
-            half2 s = make_half2(0.0, 0.0);
-
-#pragma unroll
-            for (int i = 0; i < tph; ++i) {
-                s += ss[hiiw*tph + i];
+                if (smc == -INFINITY) {
+                    continue;
+                }
             }
 
-            s = s*scale_h + mv; // s*scale + mv
+            // Q*K^T
+            {
+                half16x16_a   mq{};
+                half16x16_b   mk{};
+                half16x16_acc mqk{};
 
-            half2 m = M;
+                for (int cc = 0; cc < C/16; ++cc) {
+                    nvcuda::wmma::fill_fragment(mqk, 0); // re fetch
 
-            M = __hmax2(M, s);
+                    const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-            half2 ms = h2exp(m - M);
-            half2 vs = h2exp(s - M);
+                    for (int i = 0; i < D16; ++i) {
+                        nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T);
+                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
+                        nvcuda::wmma::mma_sync(mqk, mq, mk, mqk);
+                    }
 
-            S = S*ms + vs;
+                    nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major);
+                }
+            }
 
-            ss[2*hiiw + 0] = ms;
-            ss[2*hiiw + 1] = vs;
+            // online softmax
+            for (int64_t j = 0; j < Q; ++j) {
+                const int64_t p = lane_id;
+
+                const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
+
+                half m = M[j];
+
+                M[j] = warp_reduce_max(__hmax(M[j], s));
+
+                const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]);
+                const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]);
+
+                S[j] = S[j]*ms + warp_reduce_sum(vs);
+
+                ss[j*T + p] = vs;
+            }
+
+            __syncthreads();
+
+            // (Q*K^T)*V
+            {
+                half16x16_acc mqkv{};
+                half16x16_a   mqk{};
+                half16x16_b   mv{};
+
+                for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::fill_fragment(mqkv, 0);
+
+                    for (int cc = 0; cc < C/16; ++cc) {
+                        const half * pv = (const half *) ((const char *) v + ((iic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                        nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T);
+                        nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
+
+                        nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv);
+                    }
+
+                    nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major);
+                }
+            }
+        }
 
-        }
-
-        __syncthreads();
-
-        half2 ms = ss[2*hiiw + 0];
-        half2 vs = ss[2*hiiw + 1];
-
-#pragma unroll
-        for (int i = 0; i < D2/tph; ++i) {
-            ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs;
+        for (int64_t j = 0; j < Q; ++j) {
+            if (lane_id == 0) {
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
+            }
         }
     }
 
-    if (tiih == 0) {
-        ss[2*hiiw + 0] = S;
-        ss[2*hiiw + 1] = M;
-    }
-
     __syncthreads();
 
     // reduce the warps
+    // TODO: try parallel reduce
     if (warp_id == 0) {
+        half S = 0.0;
+        half M = -INFINITY;
 
-        for (int sg = 1; sg < nwraps; ++sg) {
-            half2 S0 = ss[               2*hiiw + 0];
-            half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0];
+        for (int64_t sg = 1; sg < n_warps; ++sg) {
+            for (int64_t j = 0; j < Q; ++j) {
+                const half S0 = ss[j*T                 + 0];
+                const half S1 = ss[j*T + sg*(D + 1*C) + 0];
 
-            half2 M0 = ss[               2*hiiw + 1];
-            half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1];
+                const half M0 = ss[j*T                 + 1];
+                const half M1 = ss[j*T + sg*(D + 1*C) + 1];
 
-            M = __hmax2(M0, M1);
+                M = __hmax(M0, M1);
 
-            half2 ms0 = h2exp(M0 - M);
-            half2 ms1 = h2exp(M1 - M);
+                const half ms0 = hexp(M0 - M);
+                const half ms1 = hexp(M1 - M);
 
-            S = S0*ms0 + S1*ms1;
+                S = S0*ms0 + S1*ms1;
 
-            if (tiih == 0) {
-                ss[2*hiiw + 0] = S;
-                ss[2*hiiw + 1] = M;
+                if (lane_id == 0) {
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
+                }
+
+                for (int64_t i = 0; i < L2; ++i) {
+                    ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1);
+                }
             }
 
-            for (int i = 0; i < D2/tph; ++i) {
-                ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1;
-            }
-        }
-
-        for (int i = 0; i < D2/tph; ++i) {
-            ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S);
         }
     }
 
     __syncthreads();
 
-    // dst indices
-    const int i1 = iq1;
-    const int i2 = iq2;
-    const int i3 = iq3;
-
     float2 * dst2 = (float2 *) kqv;
+
     if (warp_id == 0) {
-        for (int i = 0; i < D2/tph; ++i) {
-            dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
+            half2 S = __half2half2(ss[j*T + 0]);
+
+            for (int i = 0; i < L2; ++i) {
+                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S);
+            }
         }
     }
-#endif
 }
@@ -10300,7 +10391,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     memcpy(&scale, KQV->op_params, sizeof(float));
 
     const int nwarps = Q->ne[1] < 4 ? 4 : 2;
-    const int nqpb = 2;  // queries per block
+    const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
@@ -10311,7 +10402,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     switch (Q->ne[0]) {
         case 64:
-            flash_attn_ext_f16<64, 8, 32>
+            flash_attn_ext_f16<64, 16, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10328,7 +10419,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
             );
             break;
         case 80:
-            flash_attn_ext_f16<80, 8, 32>
+            flash_attn_ext_f16<80, 16, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10345,7 +10436,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
             );
            break;
         case 128:
-            flash_attn_ext_f16<128, 8, 32>
+            flash_attn_ext_f16<128, 16, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
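
For reference, below is a minimal, self-contained sketch of the XOR-shuffle butterfly reduction that the new half-precision warp_reduce_sum/warp_reduce_max overloads in this patch rely on. It is not part of the patch: the file name and kernel name are made up for illustration, and it assumes a full 32-lane warp on a GPU with native half arithmetic (compute capability 5.3 or newer).

// warp_sum_sketch.cu -- illustrative only, not part of ggml-cuda.cu
#include <cstdio>
#include <cuda_fp16.h>

// Butterfly reduction: each XOR step pairs lanes whose ids differ in one bit,
// halving the partner distance (16, 8, 4, 2, 1). After five steps every lane
// holds the sum of all 32 lanes.
static __device__ __forceinline__ half warp_sum(half x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = __hadd(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

__global__ void warp_sum_demo(const half * in, half * out) {
    const half s = warp_sum(in[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = s; // lane 0 publishes the warp-wide result
    }
}

int main() {
    half h_in[32];
    for (int i = 0; i < 32; ++i) {
        h_in[i] = __float2half((float) i); // 0 + 1 + ... + 31 = 496
    }

    half * d_in  = nullptr;
    half * d_out = nullptr;
    cudaMalloc(&d_in,  32*sizeof(half));
    cudaMalloc(&d_out,     sizeof(half));
    cudaMemcpy(d_in, h_in, 32*sizeof(half), cudaMemcpyHostToDevice);

    warp_sum_demo<<<1, 32>>>(d_in, d_out);

    half h_out;
    cudaMemcpy(&h_out, d_out, sizeof(half), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f\n", __half2float(h_out)); // expect 496.0

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Built with e.g. `nvcc -arch=sm_70 warp_sum_sketch.cu`, this should print "warp sum = 496.0". The flash_attn_ext_f16 kernel applies the same pattern per query row of the online softmax: warp_reduce_sum(vs) combines the per-lane terms of the running denominator S[j], and warp_reduce_max combines the running maximum M[j] across the warp.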