From 865af990cc5426f9bee37cca3d5670edd41eeaae Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 May 2024 14:50:28 +0300
Subject: [PATCH] ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci
---
 ggml-cuda/fattn.cu   | 72 ++++++++++++++++++++++++++++++++++++++------
 ggml-cuda/softmax.cu |  8 ++---
 2 files changed, 66 insertions(+), 14 deletions(-)

diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 7c486f482..ac5d6672b 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             sum2[j] = warp_reduce_sum(sum2[j]);
             half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-            sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
             const int k = k0 + threadIdx.x;
 
-            KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+            KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
             KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
         }
         KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
            const int k = k0 + threadIdx.x;
 
-            KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+            KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
             KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
         }
         KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const int shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
         GGML_ASSERT(precision == GGML_PREC_DEFAULT);
diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu
index 917f15a07..ca85285a3 100644
--- a/ggml-cuda/softmax.cu
+++ b/ggml-cuda/softmax.cu
@@ -30,9 +30,9 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
         const int h = rowx/nrows_y; // head index
 
         const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        slope = powf(base, exp);
+        slope = powf(base, exph);
     }
 
     extern __shared__ float data_soft_max_f32[];
@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
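
Note (not part of the patch): the launcher hunks above compute the ALiBi slope schedule on the host (max_bias, m0, m1, n_head_log2) and pass it to the kernels, which then scale the mask by the per-head slope (slopeh*maskh[...], slope2*mask2[...]). The standalone sketch below restates that computation under the same formulas; the helper name alibi_slope and the host-only setting are illustrative, not code from the patch.

// Sketch: per-head ALiBi slope as derived in the launchers above.
//   n_head_log2 = 2^floor(log2(n_head))
//   m0 = 2^(-max_bias / n_head_log2),  m1 = 2^(-(max_bias/2) / n_head_log2)
//   slope(h) = m0^(h+1)                     for h <  n_head_log2
//   slope(h) = m1^(2*(h - n_head_log2) + 1) for h >= n_head_log2
// A max_bias of 0.0f disables ALiBi and leaves the mask unscaled (slope == 1).
#include <math.h>
#include <stdint.h>

static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}

In the kernels this value ends up in slopeh (and slope2 for the half2 path) with h taken from blockIdx.y, so the per-head bias is folded into the existing mask addition rather than handled as a separate pass.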