Fix undefined behavior (UB) in the argmax_f32 warp shuffle reduction
This commit is contained in:
parent
a734da71ce
commit
316f3d3116
1 changed file with 7 additions and 6 deletions
|
@@ -44,14 +44,15 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
if (warp_id == 0 && lane_id < n_warps) {
|
if (warp_id == 0) {
|
||||||
|
if (lane_id < n_warps) {
|
||||||
maxval = shared_maxval[lane_id];
|
maxval = shared_maxval[lane_id];
|
||||||
argmax = shared_argmax[lane_id];
|
argmax = shared_argmax[lane_id];
|
||||||
const unsigned int mask = (1u << n_warps) - 1u;
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE);
|
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
||||||
const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE);
|
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
||||||
if (val > maxval) {
|
if (val > maxval) {
|
||||||
maxval = val;
|
maxval = val;
|
||||||
argmax = col;
|
argmax = col;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue