diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8c47ab059..d5483ac9b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5247,7 +5247,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds extern __shared__ half data_soft_max_f16[]; half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication // (shared memory) buffer to cache values between iterations: - half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data); + half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data); half2 max_val = make_half2(-INFINITY, -INFINITY); @@ -5371,7 +5371,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols; + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; float max_val = -INFINITY;