diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index b7740ddc2..5340eedc0 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -44,14 +44,15 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest __syncthreads(); - if (warp_id == 0 && lane_id < n_warps) { - maxval = shared_maxval[lane_id]; - argmax = shared_argmax[lane_id]; - const unsigned int mask = (1u << n_warps) - 1u; + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { - const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE); - const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE); + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { maxval = val; argmax = col;