fix bug in pool2d kernel
This commit is contained in:
parent
1556d4ca17
commit
379f89fbbe
1 changed files with 4 additions and 4 deletions
|
@ -6045,15 +6045,15 @@ static __global__ void pool2d_nchw_kernel(
|
|||
return;
|
||||
const int I_HW = ih * iw;
|
||||
const int O_HW = oh * ow;
|
||||
const int nc = idx / (oh * ow);
|
||||
const int cur_oh = idx % (oh * ow) / ow;
|
||||
const int cur_ow = idx % (oh * ow) % ow;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / ow;
|
||||
const int cur_ow = idx % O_HW % ow;
|
||||
const Ti* i_ptr = src + nc * I_HW;
|
||||
To* o_ptr = dst + nc * O_HW;
|
||||
const int start_h = cur_oh * sh - ph;
|
||||
const int bh = max(0, start_h);
|
||||
const int eh = min(ih, start_h + kh);
|
||||
const int start_w = ow * sw - pw;
|
||||
const int start_w = cur_ow * sw - pw;
|
||||
const int bw = max(0, start_w);
|
||||
const int ew = min(iw, start_w + kw);
|
||||
const To scale = 1. / ((eh - bh) * (ew - bw));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue