From 44f30434aa8c0ac6e0986ceb17c0cc386ff03499 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 3 Jan 2024 17:29:27 +0100 Subject: [PATCH] fixup! CUDA: faster softmax via shared memory + fp16 math --- ggml-cuda.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d5483ac9b..b58af8040 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5233,7 +5233,7 @@ template = CC_PASCAL const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; - const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE); + const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; const int tid = threadIdx.x; const int rowx = blockIdx.x; @@ -5270,7 +5270,9 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } else { val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); } - vals[col_smem] = val; + if (!need_check || col_smem < ncols_smem) { + vals[col_smem] = val; + } max_val = __hmax2(max_val, val); } @@ -6730,7 +6732,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half); + const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); if (shmem <= g_device_caps[g_main_device].smpb) { switch (ncols_x) { @@ -6773,7 +6775,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float); + const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); if (shmem < g_device_caps[g_main_device].smpb) { switch (ncols_x) {