diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 96976f248..562595b05 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5957,7 +5957,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5971,6 +5971,21 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; + float slope = 0.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t n_head_kv = gridDim.x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); + + const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); + const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int h = rowx/nrows_y; // head index + + slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + } + extern __shared__ half data_soft_max_f16[]; half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication // (shared memory) buffer to cache values between iterations: @@ -5992,12 +6007,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f) + slope*col_data; } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f) + slope*(col_data + WARP_SIZE); } if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6087,7 +6102,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6099,6 +6114,21 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; + float slope = 0.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t n_head_kv = gridDim.x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); + + const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); + const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int h = rowx/nrows_y; // head index + + slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + } + extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: @@ -6117,7 +6147,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? y[iy] : 0.0f) + slope*col; + vals[col] = val; max_val = max(max_val, val); } @@ -7589,7 +7620,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7599,40 +7630,40 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con if (shmem <= g_device_caps[g_main_device].smpb) { switch (ncols_x) { case 32: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 64: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 128: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 256: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 512: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 1024: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 2048: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 4096: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; default: - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(half); - soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f16<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7642,36 +7673,36 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con if (shmem < g_device_caps[g_main_device].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 64: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 128: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 256: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 512: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 1024: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 2048: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; case 4096: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; default: - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale, max_bias); } } @@ -9091,10 +9122,13 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src0->ne[1]; - float scale = 1.0f; - memcpy(&scale, dst->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX #ifdef GGML_CUDA_F16 @@ -9107,9 +9141,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } (void) dst; diff --git a/ggml-metal.metal b/ggml-metal.metal index 19b02880d..4b6514014 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -373,6 +373,7 @@ kernel void kernel_soft_max( float slope = 0.0f; + // ALiBi if (max_bias > 0.0f) { const uint32_t n_head_kv = ne02; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); diff --git a/ggml.c b/ggml.c index ea1b31d9f..6cd5e47e1 100644 --- a/ggml.c +++ b/ggml.c @@ -11516,7 +11516,7 @@ static void ggml_compute_forward_soft_max_f32( ggml_vec_acc_f32(nc, wp, mp); } - // alibi bias + // ALiBi bias if (max_bias > 0.0f) { const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 56e5fa920..e2e676ba7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2093,7 +2093,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int n = 0; n < 10; ++n) { int64_t ne0 = dist_ne0(rng); int64_t ne1 = dist_ne1(rng); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, 4.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f)); } exponent <<= 1;