fix pool2d_kernel
nits
This commit is contained in:
parent
41a34cb3de
commit
1556d4ca17
1 changed file with 4 additions and 2 deletions
|
@@ -6038,9 +6038,11 @@ template <typename Ti, typename To>
|
|||
static __global__ void pool2d_nchw_kernel(
|
||||
const int ih, const int iw, const int oh, const int ow,
|
||||
const int kh, const int kw, const int sh, const int sw,
|
||||
const int ph, const int pw,
|
||||
const int ph, const int pw, const int parallel_elements,
|
||||
const Ti* src, To* dst, const enum ggml_op_pool op) {
|
||||
int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if(idx >= parallel_elements)
|
||||
return;
|
||||
const int I_HW = ih * iw;
|
||||
const int O_HW = oh * ow;
|
||||
const int nc = idx / (oh * ow);
|
||||
|
@@ -8737,7 +8739,7 @@ static void ggml_cuda_op_pool2d(
|
|||
const int parallel_elements = N * OC * OH * OW;
|
||||
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
|
||||
dim3 block_nums(num_blocks);
|
||||
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op);
|
||||
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
|
||||
|
||||
(void) src0;
|
||||
(void) src0_dd;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue