From ba5592c65336159e7de379f1e39a1354ce287b74 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Mon, 29 Jan 2024 10:31:41 +0800 Subject: [PATCH] CUDA POOL2D --- ggml-cuda.cu | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++ ggml.c | 16 ---------- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index db136ae9c..4fb5ce784 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -530,6 +530,7 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16 #define CUDA_PAD_BLOCK_SIZE 256 #define CUDA_ACC_BLOCK_SIZE 256 #define CUDA_IM2COL_BLOCK_SIZE 256 +#define CUDA_POOL2D_BLOCK_SIZE 256 #define CUDA_Q8_0_NE_ALIGN 2048 @@ -6033,6 +6034,48 @@ static __global__ void im2col_kernel( } } +template +static __global__ void pool2d_nchw_kernel( + const int ih, const int iw, const int oh, const int ow, + const int kh, const int kw, const int sh, const int sw, + const int ph, const int pw, + const Ti* src, To* dst, const enum ggml_op_pool op) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + const int I_HW = ih * iw; + const int O_HW = oh * ow; + const int nc = idx / (oh * ow); + const int cur_oh = idx % (oh * ow) / ow; + const int cur_ow = idx % (oh * ow) % ow; + const Ti* i_ptr = src + nc * I_HW; + To* o_ptr = dst + nc * O_HW; + const int start_h = cur_oh * sh - ph; + const int bh = max(0, start_h); + const int eh = min(ih, start_h + kh); + const int start_w = ow * sw - pw; + const int bw = max(0, start_w); + const int ew = min(iw, start_w + kw); + const To scale = 1. / ((eh - bh) * (ew - bw)); + To res = 0; + switch(op){ + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; + } + for(int i = bh; i < eh; i += 1){ + for(int j = bw; j < ew; j += 1){ + #if __CUDA_ARCH__ >= 350 + Ti cur = __ldg(i_ptr + i * iw + j); + #else + Ti cur = i_ptr[i * iw + j]; + #endif + switch(op){ + case GGML_OP_POOL_AVG: res += cur * scale; break; + case GGML_OP_POOL_MAX: res = max(res, (To)cur); break; + } + } + } + o_ptr[cur_oh * ow + cur_ow] = res; +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -8670,6 +8713,38 @@ static void ggml_cuda_op_alibi( (void) src1_dd; } +static void ggml_cuda_op_pool2d( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = static_cast(opts[0]); + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + const int64_t IC = src0->ne[2]; + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int parallel_elements = N * OC * OH * OW; + const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; + dim3 block_nums(num_blocks); + pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op); + + (void) src0; + (void) src0_dd; +} + + static void ggml_cuda_op_im2col( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { @@ -10084,6 +10159,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } +static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d); +} + static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } @@ -10265,6 +10344,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_IM2COL: func = ggml_cuda_im2col; break; + case GGML_OP_POOL_2D: + func = ggml_cuda_pool2d; + break; case GGML_OP_SUM_ROWS: func = ggml_cuda_sum_rows; break; diff --git a/ggml.c b/ggml.c index 1c74d80e3..ccb1bfb5e 100644 --- a/ggml.c +++ b/ggml.c @@ -5580,21 +5580,6 @@ struct ggml_tensor * ggml_pool_2d( } struct ggml_tensor * result; -#if defined(GGML_USE_CUBLAS) - if(!(op == GGML_OP_POOL_AVG)) { - GGML_ASSERT(false); - } - - const int64_t ne[4] = {k0, k1, 1, a->ne[2]}; - struct ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); - struct ggml_tensor * im2col = ggml_im2col(ctx, b, new_a, - s0, s1, p0, p1, 1, 1, true, GGML_TYPE_F32); // [N * IC, OH, OW, KH * KW] - - result = ggml_sum_rows(ctx, im2col); - result = ggml_scale(ctx, result, 1. / (k0 * k1)); - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[2], a->ne[3]); -#else const int64_t ne[3] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), @@ -5608,7 +5593,6 @@ struct ggml_tensor * ggml_pool_2d( result->op = GGML_OP_POOL_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; -#endif return result; }