From 316f3d31163e3cf7ea85ea73f2d2d15720b56115 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 21 Nov 2024 13:48:43 +0100 Subject: [PATCH] fix ub --- ggml/src/ggml-cuda/argmax.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index b7740ddc2..5340eedc0 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -44,14 +44,15 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest __syncthreads(); - if (warp_id == 0 && lane_id < n_warps) { - maxval = shared_maxval[lane_id]; - argmax = shared_argmax[lane_id]; - const unsigned int mask = (1u << n_warps) - 1u; + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col;