From 1e9447a00b430f0a93daa0d5e7cd074c483950e4 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 21 Nov 2024 02:55:22 +0100 Subject: [PATCH] fixup : use full warps ggml-ci --- ggml/src/ggml-cuda/argmax.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 0523b934a..94ab5df05 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -47,7 +47,7 @@ static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t if (warp_id == 0 && lane_id < n_warps) { maxval = shared_maxval[lane_id]; argmax = shared_argmax[lane_id]; - const unsigned int mask = (1 << n_warps) - 1; + const unsigned int mask = (1u << n_warps) - 1u; #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE); @@ -82,7 +82,8 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t stream = ctx.stream(); const int64_t num_blocks = nrows; - const dim3 blocks_dim(std::min(ne00, 1024), 1, 1); + const int64_t num_threads = std::min(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); + const dim3 blocks_dim(num_threads, 1, 1); const dim3 blocks_num(num_blocks, 1, 1); argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);