Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Diego Devesa 2024-11-21 13:32:48 +01:00 committed by GitHub
parent 1e9447a00b
commit a734da71ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,7 +5,7 @@
#include "common.cuh" #include "common.cuh"
#include "sum.cuh" #include "sum.cuh"
static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols) { static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
const int64_t row = blockIdx.x; const int64_t row = blockIdx.x;
float maxval = -FLT_MAX; float maxval = -FLT_MAX;
@ -30,7 +30,7 @@ static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t
} }
} }
const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; const int n_warps = blockDim.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
if (n_warps > 1) { if (n_warps > 1) {