Fix pool2d kernel: compute start_w from cur_ow instead of ow, and reuse O_HW in the index math

This commit is contained in:
zhangjidong 2024-01-30 11:06:37 +08:00
parent 1556d4ca17
commit 379f89fbbe

View file

@@ -6045,15 +6045,15 @@ static __global__ void pool2d_nchw_kernel(
return;
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / (oh * ow);
const int cur_oh = idx % (oh * ow) / ow;
const int cur_ow = idx % (oh * ow) % ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = max(0, start_h);
const int eh = min(ih, start_h + kh);
const int start_w = ow * sw - pw;
const int start_w = cur_ow * sw - pw;
const int bw = max(0, start_w);
const int ew = min(iw, start_w + kw);
const To scale = 1. / ((eh - bh) * (ew - bw));