From e1936bb52f14c346ff069c7d745a467b3cdc0365 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 3 Jan 2024 16:56:27 +0100 Subject: [PATCH] fixup! fixup! CUDA: faster softmax via shared memory + fp16 math --- ggml-cuda.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8c47ab059..d5483ac9b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5247,7 +5247,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds extern __shared__ half data_soft_max_f16[]; half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication // (shared memory) buffer to cache values between iterations: - half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data); + half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data); half2 max_val = make_half2(-INFINITY, -INFINITY); @@ -5371,7 +5371,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols; + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; float max_val = -INFINITY;