From 1556d4ca17718417a6dad9bf73939625c2b2e7a0 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 10:28:18 +0800 Subject: [PATCH] fix pool2d_kernel nits --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a8f9265f1..3c9863dae 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6038,9 +6038,11 @@ template static __global__ void pool2d_nchw_kernel( const int ih, const int iw, const int oh, const int ow, const int kh, const int kw, const int sh, const int sw, - const int ph, const int pw, + const int ph, const int pw, const int parallel_elements, const Ti* src, To* dst, const enum ggml_op_pool op) { int idx = threadIdx.x + blockIdx.x * blockDim.x; + if(idx >= parallel_elements) + return; const int I_HW = ih * iw; const int O_HW = oh * ow; const int nc = idx / (oh * ow); @@ -8737,7 +8739,7 @@ static void ggml_cuda_op_pool2d( const int parallel_elements = N * OC * OH * OW; const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; dim3 block_nums(num_blocks); - pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op); + pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op); (void) src0; (void) src0_dd;