From 2455a8d6c3b2e49cc19155aeb8e12438fd6a42fa Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 27 Jan 2024 12:23:40 -0500 Subject: [PATCH] update impl --- ggml-cuda.cu | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5cb065606..ecfa98c4e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32( } } -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_acc; // based on metal version template // D head size, Q queries per block, C cache items per blocks @@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16( for (int i = 0; i < L2; ++i) { // load heads from Q to shared memory for (int j = warp_id; j < Q; j += n_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id]; + pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]); } else { pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); } @@ -6218,8 +6219,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - half S[Q] = { 0.0 }; - half M[Q] = { -INFINITY }; + half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? + half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization // assume K and V are same shape const int ne22 = ne12; @@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { - half16x16_a mq{}; - half16x16_b mk{}; - half16x16_acc mqk{}; + half16x16_a mq; + half16x16_b mk; + half16x16_acc mqk; for (int cc = 0; cc < C/16; ++cc) { - nvcuda::wmma::fill_fragment(mqk, 0); // re fetch + nvcuda::wmma::fill_fragment(mqk, 0); const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16( } // online softmax - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]); @@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16( S[j] = S[j]*ms + warp_reduce_sum(vs); + for (int i = 0; i < L2; ++i) { + ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms); + } + ss[j*T + p] = vs; } @@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16( // (Q*K^T)*V { - half16x16_acc mqkv{}; - half16x16_a mqk{}; - half16x16_b mv{}; + half16x16_acc mqkv; + half16x16_a mqk; + half16x16_b mv; for (int64_t i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(mqkv, 0); @@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16( // TODO: try parallel reduce if (warp_id == 0) { half S = 0.0; - half M = -INFINITY; + half M = __float2half(-INFINITY); for (int64_t sg = 1; sg < n_warps; ++sg) { for (int64_t j = 0; j < Q; ++j) { @@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16( } } } - } - template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { - GGML_ASSERT(Q->type == GGML_TYPE_F16); + GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(mask->type == GGML_TYPE_F32); @@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nwarps = Q->ne[1] < 4 ? 4 : 2; + const int nwarps = Q->ne[1] < 4 ? 12 : 4; const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values)