diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0eeee7484..940ffbfc8 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -937,6 +937,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
+        __syncthreads();
         tmp = s_sum[lane_id];
         tmp = warp_reduce_sum(tmp);
@@ -6106,6 +6107,211 @@ static __global__ void flash_attn_f32(
     }
 }
 
+struct __align__(8) half4 {
+    half x;
+    half y;
+    half z;
+    half w;
+};
+
+// minimal element-wise helpers for half4, needed by flash_attn_ext_f16 below
+static __device__ __forceinline__ half4 make_half4(const half v) {
+    half4 r; r.x = v; r.y = v; r.z = v; r.w = v; return r;
+}
+static __device__ __forceinline__ half4 operator*(const half4 a, const half4 b) {
+    half4 r; r.x = a.x*b.x; r.y = a.y*b.y; r.z = a.z*b.z; r.w = a.w*b.w; return r;
+}
+static __device__ __forceinline__ half4 operator*(const half4 a, const half b) {
+    half4 r; r.x = a.x*b; r.y = a.y*b; r.z = a.z*b; r.w = a.w*b; return r;
+}
+static __device__ __forceinline__ half4 operator/(const half4 a, const half b) {
+    half4 r; r.x = a.x/b; r.y = a.y/b; r.z = a.z/b; r.w = a.w/b; return r;
+}
+static __device__ __forceinline__ half4 operator+(const half4 a, const half4 b) {
+    half4 r; r.x = a.x + b.x; r.y = a.y + b.y; r.z = a.z + b.z; r.w = a.w + b.w; return r;
+}
+static __device__ __forceinline__ half4 & operator+=(half4 & a, const half4 b) {
+    a = a + b; return a;
+}
+
+// based on the Metal version
+template<int D, int R> // D = head size, R = rows (heads) per block
+static __global__ void flash_attn_ext_f16(
+        const char* __restrict__ q,
+        const char* __restrict__ k,
+        const char* __restrict__ v,
+        const char* __restrict__ mask,
+        float* __restrict__ kqv,
+        float scale,
+        int ne00, int ne01, int ne02, int ne03,
+        int ne10, int ne11, int ne12, int ne13,
+        int ne31, int nb31,
+        int nb01, int nb02, int nb03,
+        int nb11, int nb12, int nb13,
+        int ne0, int ne1, int ne2, int ne3) {
+    // the block is launched as (WARP_SIZE, nwarps, 1)
+    const int warp_id = threadIdx.y;
+    const int lane_id = threadIdx.x;
+
+    const int nwarps = blockDim.y;    // number of warps
+    const int tph    = WARP_SIZE / R; // threads per head
+    const int iq3    = blockIdx.z;
+    const int iq2    = blockIdx.y*R + lane_id/tph;
+    const int iq1    = blockIdx.x;
+
+    if (iq2 >= ne02) {
+        return;
+    }
+
+    // broadcast
+    const int rk2 = ne02 / ne12;
+    const int rk3 = ne03 / ne13;
+
+    // assume the same K and V shape
+    // const int rv2 = ne02 / ne12;
+    // const int rv3 = ne03 / ne13;
+
+    // kv indices
+    const int ik2 = iq2 / rk2;
+    const int ik3 = iq3 / rk3;
+    const int iv2 = iq2 / rk2; // same as ik2/ik3, since K and V have the same shape
+    const int iv3 = iq3 / rk3;
+
+    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+
+    const float * mp = mask ? (const float *) (mask + (ir % ne31)*nb31) : nullptr;
+
+    extern __shared__ char shmem__[];
+
+    // shared memory layout (in half elements):
+    //   [R*D]                  Q rows, shared by all warps
+    //   per warp: [R*D]        output accumulator
+    //   per warp: [32]         scratch for scores and the (S, M) pairs
+    half  * sh  = (half  *) shmem__;
+    half4 * pq4 = (half4 *) sh;
+    half4 * ps4 = (half4 *) (sh + warp_id*(R*D + 32) + 1*R*D);
+    half  * ss  =            sh + warp_id*(R*D + 32) + 2*R*D;
+
+    const int tiih  = lane_id % tph; // thread index in head
+    const int hiisg = lane_id / tph; // head index in warp
+
+    const int D4 = D/4;
+
+    // load R heads from Q to shared memory
+    for (int64_t i = 0; i < D4/tph; ++i) {
+        if (warp_id == 0) {
+            pq4[hiisg*D4 + tph*i + tiih] = ((const half4 *) ((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
+        }
+
+        ps4[hiisg*D4 + tph*i + tiih] = make_half4(__float2half(0.0f));
+    }
+
+    __syncthreads();
+
+    half S = __float2half(0.0f);
+    half M = __float2half(-INFINITY);
+
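+    // online softmax over the KV sequence: M is the running maximum score and S the
+    // running sum of exp(score - M); previously accumulated values are rescaled by
+    // exp(M_old - M_new) whenever a new maximum is found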
+    for (int64_t ic = warp_id; ic < ne11; ic += nwarps) {
+        const float mvf = mp ? mp[ic] : 0.0f;
+        if (mvf == -INFINITY) {
+            continue;
+        }
+        const half mv = __float2half(mvf);
+
+        const half4 * pk4 = (const half4 *) (k + (ic*nb11 + ik2*nb12 + ik3*nb13));
+        const half4 * pv4 = (const half4 *) (v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V has the same shape as K
+
+        half4 s4 = make_half4(__float2half(0.0f));
+
+#pragma unroll
+        for (int i = 0; i < D4/tph; ++i) {
+            s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
+        }
+
+        // per-thread partial dot product of the Q row with this K row
+        ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
+
+        __syncwarp();
+
+        if (tiih == 0) {
+            half s = __float2half(0.0f);
+
+#pragma unroll
+            for (int i = 0; i < tph; ++i) {
+                s += ss[hiisg*tph + i];
+            }
+
+            s = s*__float2half(scale) + mv;
+
+            const half m = M;
+
+            M = s > M ? s : M;
+
+            const half ms = hexp(m - M);
+            const half vs = hexp(s - M);
+
+            S = S*ms + vs;
+
+            ss[2*hiisg + 0] = ms;
+            ss[2*hiisg + 1] = vs;
+        }
+
+        __syncwarp();
+
+        const half ms = ss[2*hiisg + 0];
+        const half vs = ss[2*hiisg + 1];
+
+#pragma unroll
+        for (int i = 0; i < D4/tph; ++i) {
+            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
+        }
+    }
+
+    if (tiih == 0) {
+        ss[2*hiisg + 0] = S;
+        ss[2*hiisg + 1] = M;
+    }
+
+    __syncthreads();
+
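+    // every warp has produced a partial accumulator and an (S, M) pair for its R heads;
+    // the block-wide barrier above makes them visible to warp 0 for the final merge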
+    // reduce the warps
+    if (warp_id == 0) {
+        for (int sg = 1; sg < nwarps; ++sg) {
+            const half S0 = ss[                2*hiisg + 0];
+            const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
+
+            const half M0 = ss[                2*hiisg + 1];
+            const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
+
+            M = M0 > M1 ? M0 : M1;
+
+            const half ms0 = hexp(M0 - M);
+            const half ms1 = hexp(M1 - M);
+
+            S = S0*ms0 + S1*ms1;
+
+            if (tiih == 0) {
+                ss[2*hiisg + 0] = S;
+                ss[2*hiisg + 1] = M;
+            }
+
+            for (int i = 0; i < D4/tph; ++i) {
+                ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[(sg*(R*D + 32))/4 + hiisg*D4 + tph*i + tiih]*ms1;
+            }
+        }
+
+        // normalize the accumulated values by the final sum S
+        for (int i = 0; i < D4/tph; ++i) {
+            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
+        }
+    }
+
+    __syncthreads();
+
+    // dst indices
+    const int i1 = iq1;
+    const int i2 = iq2;
+    const int i3 = iq3;
+
+    float4 * dst4 = (float4 *) kqv;
+
+    if (warp_id == 0) {
+        // convert the f16 accumulator to f32 and store it to the destination tensor
+        for (int i = 0; i < D4/tph; ++i) {
+            float4 &    dst_ = dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih];
+            const half4 src_ = ps4[hiisg*D4 + tph*i + tiih];
+
+            dst_.x = __half2float(src_.x);
+            dst_.y = __half2float(src_.y);
+            dst_.z = __half2float(src_.z);
+            dst_.w = __half2float(src_.w);
+        }
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                               const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -10071,6 +10277,98 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
         KQ_scale, d_head, sequence_length, num_heads, main_stream);
 }
 
+
+inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
+    GGML_ASSERT(Q->type    == GGML_TYPE_F16);
+    GGML_ASSERT(K->type    == GGML_TYPE_F16);
+    GGML_ASSERT(V->type    == GGML_TYPE_F16);
+    GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->backend    == GGML_BACKEND_GPU);
+    GGML_ASSERT(K->backend    == GGML_BACKEND_GPU);
+    GGML_ASSERT(V->backend    == GGML_BACKEND_GPU);
+    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(KQV->backend  == GGML_BACKEND_GPU);
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
+    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
+    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+    const int nwarps = 32;
+    const int nhpw   = 2; // heads per warp
+
+    dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/nhpw, Q->ne[3]);
+    dim3 block_dim(32, nwarps, 1);
+
+    // shared memory: nhpw*D halfs for Q plus (nhpw*D + 32) halfs of accumulator/scratch per warp
+    const int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*sizeof(half);
+
+    switch (Q->ne[0]) {
+        case 64:
+            flash_attn_ext_f16<64, 2>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    (const char *) src3_extra->data_device[g_main_device], // Mask
+                    (float *)       dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask->ne[1], mask->nb[1],
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+            break;
+        case 80:
+            flash_attn_ext_f16<80, 2>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    (const char *) src3_extra->data_device[g_main_device], // Mask
+                    (float *)       dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask->ne[1], mask->nb[1],
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+            break;
+        case 128:
+            flash_attn_ext_f16<128, 2>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    (const char *) src3_extra->data_device[g_main_device], // Mask
+                    (float *)       dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask->ne[1], mask->nb[1],
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+            break;
+        default: // unsupported head size: nothing is launched
+            break;
+    }
+}
+
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
@@ -10341,6 +10639,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
             break;
         case GGML_OP_FLASH_ATTN:
            break;
+        case GGML_OP_FLASH_ATTN_EXT:
+           break;
        default:
            return false;
    }
@@ -10357,7 +10657,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
    }
    if(tensor->op == GGML_OP_FLASH_ATTN) {
        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
-    } else {
+    } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+        ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+    } else {
        func(tensor->src[0], tensor->src[1], tensor);
    }
    return true;
@@ -11175,6 +11477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
            return true;
        default:
            return false;
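Note: as a quick sanity check of the dynamic shared memory requested by ggml_cuda_flash_attn_ext above, here is a standalone sketch (not part of the patch) that reproduces the size computation for the largest supported head size, assuming D = 128, nhpw = 2 and nwarps = 32 as in the launch code:

// layout in half (2-byte) elements:
//   R*D shared Q rows + nwarps * (R*D + 32) accumulator/scratch per warp
#include <stdio.h>

int main(void) {
    const int D      = 128; // head size (Q->ne[0])
    const int R      = 2;   // heads per warp (nhpw)
    const int nwarps = 32;

    const int n_half = R*D + nwarps*(R*D + 32);
    printf("shmem = %d bytes\n", n_half*2); // 18944 bytes, well below the 48 KB default limit
    return 0;
}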