Fix pool2d_nchw_kernel: add a bounds check so threads with idx >= parallel_elements return early

nits
This commit is contained in:
zhangjidong 2024-01-30 10:28:18 +08:00
parent 41a34cb3de
commit 1556d4ca17

View file

@ -6038,9 +6038,11 @@ template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel( static __global__ void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow, const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw, const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) { const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x; int idx = threadIdx.x + blockIdx.x * blockDim.x;
if(idx >= parallel_elements)
return;
const int I_HW = ih * iw; const int I_HW = ih * iw;
const int O_HW = oh * ow; const int O_HW = oh * ow;
const int nc = idx / (oh * ow); const int nc = idx / (oh * ow);
@ -8737,7 +8739,7 @@ static void ggml_cuda_op_pool2d(
const int parallel_elements = N * OC * OH * OW; const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
dim3 block_nums(num_blocks); dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op); pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
(void) src0; (void) src0;
(void) src0_dd; (void) src0_dd;