diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3c9863dae..20e3b5efa 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6045,15 +6045,15 @@ static __global__ void pool2d_nchw_kernel( return; const int I_HW = ih * iw; const int O_HW = oh * ow; - const int nc = idx / (oh * ow); - const int cur_oh = idx % (oh * ow) / ow; - const int cur_ow = idx % (oh * ow) % ow; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / ow; + const int cur_ow = idx % O_HW % ow; const Ti* i_ptr = src + nc * I_HW; To* o_ptr = dst + nc * O_HW; const int start_h = cur_oh * sh - ph; const int bh = max(0, start_h); const int eh = min(ih, start_h + kh); - const int start_w = ow * sw - pw; + const int start_w = cur_ow * sw - pw; const int bw = max(0, start_w); const int ew = min(iw, start_w + kw); const To scale = 1. / ((eh - bh) * (ew - bw));