This commit is contained in:
slaren 2024-11-21 13:48:43 +01:00
parent a734da71ce
commit 316f3d3116

View file

@@ -44,14 +44,15 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
__syncthreads(); __syncthreads();
if (warp_id == 0 && lane_id < n_warps) { if (warp_id == 0) {
if (lane_id < n_warps) {
maxval = shared_maxval[lane_id]; maxval = shared_maxval[lane_id];
argmax = shared_argmax[lane_id]; argmax = shared_argmax[lane_id];
const unsigned int mask = (1u << n_warps) - 1u; }
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) { for (int offset = 16; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE); const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) { if (val > maxval) {
maxval = val; maxval = val;
argmax = col; argmax = col;