From a734da71ce4b1d88cbb09454adaaba6ce6d6e9c6 Mon Sep 17 00:00:00 2001
From: Diego Devesa
Date: Thu, 21 Nov 2024 13:32:48 +0100
Subject: [PATCH] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/argmax.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu
index 94ab5df05..b7740ddc2 100644
--- a/ggml/src/ggml-cuda/argmax.cu
+++ b/ggml/src/ggml-cuda/argmax.cu
@@ -5,7 +5,7 @@
 #include "common.cuh"
 #include "sum.cuh"
 
-static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols) {
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
     const int64_t row = blockIdx.x;
 
     float maxval = -FLT_MAX;
@@ -30,7 +30,7 @@ static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t
         }
     }
 
-    const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
+    const int n_warps = blockDim.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
     const int warp_id = threadIdx.x / WARP_SIZE;
     if (n_warps > 1) {