From 379f89fbbe4658522795b9ae8aff970fb0696b36 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 11:06:37 +0800 Subject: [PATCH] fix bug in pool2d kernel --- ggml-cuda.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3c9863dae..20e3b5efa 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6045,15 +6045,15 @@ static __global__ void pool2d_nchw_kernel( return; const int I_HW = ih * iw; const int O_HW = oh * ow; - const int nc = idx / (oh * ow); - const int cur_oh = idx % (oh * ow) / ow; - const int cur_ow = idx % (oh * ow) % ow; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / ow; + const int cur_ow = idx % O_HW % ow; const Ti* i_ptr = src + nc * I_HW; To* o_ptr = dst + nc * O_HW; const int start_h = cur_oh * sh - ph; const int bh = max(0, start_h); const int eh = min(ih, start_h + kh); - const int start_w = ow * sw - pw; + const int start_w = cur_ow * sw - pw; const int bw = max(0, start_w); const int ew = min(iw, start_w + kw); const To scale = 1. / ((eh - bh) * (ew - bw));