From 7fd024c5e92497f09467ad6fd87035a99bc5f7bc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 15 Feb 2024 18:25:48 +0200 Subject: [PATCH] cuda : precompute ALiBi constants --- ggml-cuda.cu | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 79487391c..a107cbe7e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5957,7 +5957,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) { +static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -5973,15 +5973,12 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f // ALiBi if (max_bias > 0.0f) { - const uint32_t n_head_kv = gridDim.x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int h = rowx/nrows_y; // head index - slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = powf(base, exp); } extern __shared__ float data_soft_max_f32[]; @@ -7482,39 +7479,46 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float * const dim3 block_nums(nrows_x, 1, 1); const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + if (shmem < g_device_caps[g_main_device].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias); + soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } }