From 972c2adc15b5d61c2b3f267989a3185d2a99ce46 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 24 Jan 2024 16:41:57 -0500 Subject: [PATCH] use half2 instead half4 --- ggml-cuda.cu | 197 ++++++++++++++++++++------------------------------- 1 file changed, 77 insertions(+), 120 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9d2b99ac9..e9657dd88 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5992,7 +5992,7 @@ static __global__ void im2col_f32_f16( #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 -template +template static __global__ void flash_attn_f32( const float* __restrict__ q, const float* __restrict__ k, @@ -6004,9 +6004,9 @@ static __global__ void flash_attn_f32( const int head_size = head_dim * seq_len; const int s = blockIdx.x % seq_len; - extern __shared__ char shmem__[]; - float* S = (float*)shmem__; - float* warp_data = (float*)(shmem__ + seq_len * sizeof(float)); + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); // QK^T #pragma unroll @@ -6019,11 +6019,11 @@ static __global__ void flash_attn_f32( const int key_offset = is * head_dim + head * head_size; const int query_offset = s * head_dim + head * head_size; - S[is] = 0.0f; + float tmp = 0.0f; for(int d = 0; d < head_dim; d++) { - S[is] += k[key_offset + d] * q[query_offset + d]; + tmp += k[key_offset + d] * q[query_offset + d]; } - S[is] *= kq_scale; + S[is] = tmp * kq_scale; } __syncthreads(); @@ -6060,9 +6060,9 @@ static __global__ void flash_attn_f32( if(is >= seq_len) { break; } - - S[is] = expf(S[is] - max_val); - sum += S[is]; + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; } __syncthreads(); @@ -6091,7 +6091,12 @@ static __global__ void flash_attn_f32( __syncthreads(); // softmax(QK^T)V - for (int d = threadIdx.x; d < head_dim; d += block_size) { + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } const int dst_index = d + s * head_dim + head * head_size; const int value_offset = d * seq_len + head * head_size; @@ -6107,51 +6112,8 @@ static __global__ void flash_attn_f32( } } -struct __align__(8) half4 { - half x; - half y; - half z; - half w; -}; - -__device__ half4 make_half4(half x) { - half4 t; - t.x = x; t.y = x; t.z = x; t.w = x; - return t; -} - -__device__ half4 __h4fma(half4 a, half b, half4 c) { - half4 t; - t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w); - return t; -} - -__device__ half4 __h4fma(half4 a, half4 b, half4 c) { - half4 t; - t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w); - return t; -} - -__device__ half4 __h4mul(half4 a, half b) { - half4 t; - t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b); - return t; -} - -__device__ half4 __h4mul(half4 a, half4 b) { - half4 t; - t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w); - return t; -} - -__device__ half4 __h4div(half4 a, half b) { - half4 t; - t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b); - return t; -} - // based on metal version -template // head size, rows per block +template // D head size, R rows per block static __global__ void flash_attn_ext_f16( const char* __restrict__ q, const char* __restrict__ k, @@ -6205,91 +6167,93 @@ static __global__ void flash_attn_ext_f16( // const int iv2 = iq2 / rv2; // const int iv3 = iq3 / rv3; + const half2 scale_h = __half2half2(__float2half(scale)); + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; - extern __shared__ char shmem__[]; + extern __shared__ char data_flash_attn_shmem[]; - half4* pq4 = (half4*)shmem__; - half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D); - half* ss = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D); + half2* pq2 = (half2*)data_flash_attn_shmem; + half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D); + half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D); - const int tiih = lane_id % tph; // thread index in head - const int hiisg = lane_id / tph; // head index in warp + const int tiih = lane_id % tph; // thread index in head + const int hiiw = lane_id / tph; // head index in warp - const int D4 = D/4; + const int D2 = D / 2; // number of half2 to store head_dim row // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { + for (int i = 0; i < D2/tph; ++i) { if (warp_id == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; } - ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0); + ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0); } __syncthreads(); - half S(0.0); - half M(-INFINITY); + half2 S = make_half2(0.0, 0.0); + half2 M = make_half2(-INFINITY, -INFINITY); for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { - const half mv = mp ? mp[ic] : 0.0; - if (__hisinf(mv) == -1) { // mv == -INFINITY + const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); + if (__hisinf(mv.x) == -1) { // mv == -INFINITY continue; } - const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K + half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K - half4 s4 = make_half4(0.0); + half2 s2 = make_half2(0.0, 0.0); #pragma unroll - for (int i = 0; i < D4/tph; ++i) { - s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4); + for (int i = 0; i < D2/tph; ++i) { + s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2; } - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y); __syncthreads(); if (tiih == 0) { - half s = 0.0; + half2 s = make_half2(0.0, 0.0); #pragma unroll for (int i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + s += ss[hiiw*tph + i]; } - s = __hfma(s, __float2half(scale), mv); // s*scale + mv + s = s * scale_h + mv; // s*scale + mv - const half m = M; + half2 m = M; - M = __hmax(M, s); + M = __hmax2(M, s); - const half ms = hexp(m - M); - const half vs = hexp(s - M); + half2 ms = h2exp(m - M); + half2 vs = h2exp(s - M); - S = __hfma(S, ms, vs); + S = S * ms + vs; - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; + ss[2*hiiw + 0] = ms; + ss[2*hiiw + 1] = vs; } __syncthreads(); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + half2 ms = ss[2*hiiw + 0]; + half2 vs = ss[2*hiiw + 1]; #pragma unroll - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs)); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs; } } if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + ss[2*hiiw + 0] = S; + ss[2*hiiw + 1] = M; } __syncthreads(); @@ -6297,31 +6261,31 @@ static __global__ void flash_attn_ext_f16( // reduce the warps if (warp_id == 0) { for (int sg = 1; sg < nwraps; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + half2 S0 = ss[ 2*hiiw + 0]; + half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + half2 M0 = ss[ 2*hiiw + 1]; + half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1]; - M = __hmax(M0, M1); + M = __hmax2(M0, M1); - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); + half2 ms0 = h2exp(M0 - M); + half2 ms1 = h2exp(M1 - M); - S = __hfma(S0, ms0, __hmul(S1, ms1)); + S = S0 * ms0 + S1 * ms1; if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + ss[2*hiiw + 0] = S; + ss[2*hiiw + 1] = M; } - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1)); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1; } } - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S); } } @@ -6332,17 +6296,10 @@ static __global__ void flash_attn_ext_f16( const int i2 = iq2; const int i3 = iq3; - float4 * dst4 = (float4 *) kqv; - + float2 * dst2 = (float2 *) kqv; if (warp_id == 0) { - for (int i = 0; i < D4/tph; ++i) { - float4 dst_ = - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih]; - half4 src_ = ps4[hiisg*D4 + tph*i + tiih]; - dst_.x = __half2float(src_.x); - dst_.y = __half2float(src_.y); - dst_.z = __half2float(src_.z); - dst_.w = __half2float(src_.w); + for (int i = 0; i < D2/tph; ++i) { + dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); } } } @@ -7741,7 +7698,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst, static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); int num_blocks = num_heads * seq_len; - flash_attn_f32<<>>( + flash_attn_f32<<>>( q, k, v, dst, kq_scale, d_head, seq_len, num_heads); } @@ -10342,11 +10299,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps = 32; const int nhpw = 2; // heads per warp - dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]); - dim3 block_dim(32, nwarps, 1); - - int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2); + dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]); + dim3 block_dim(32 * nwarps, 1, 1); + int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2); + printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); switch (Q->ne[0]) { case 64: