From 746e79e9a5d375d2efaecfad3c2ca8aeaf2617bb Mon Sep 17 00:00:00 2001 From: Junhee Yoo Date: Wed, 23 Oct 2024 17:00:17 +0900 Subject: [PATCH] apply review Signed-off-by: Junhee Yoo --- ggml/src/ggml-metal.m | 10 +++++---- ggml/src/ggml-metal.metal | 46 +++++++++++++++++++++------------------ 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index ed4d8326c..0267b0026 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -854,7 +854,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_POOL_1D: return false; case GGML_OP_POOL_2D: - return true; case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: @@ -2554,6 +2553,8 @@ static void ggml_metal_encode_node( } break; case GGML_OP_IM2COL: { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); @@ -2620,7 +2621,7 @@ static void ggml_metal_encode_node( [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; if (is_gt_mttpt) { - [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; + [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; @@ -3034,9 +3035,10 @@ static void ggml_metal_encode_node( } break; case GGML_OP_POOL_2D: { + GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); - const int32_t* opts = dst->op_params; + const int32_t * opts = dst->op_params; enum ggml_op_pool op = opts[0]; id pipeline = nil; @@ -3063,7 +3065,7 @@ static void ggml_metal_encode_node( const int64_t IH = src0->ne[1]; const int64_t IW = src0->ne[0]; - const int64_t N = dst->ne[3]; + const int64_t N = dst->ne[3]; const int64_t OC = dst->ne[2]; const int64_t OH = dst->ne[1]; const int64_t OW = dst->ne[0]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 1fb05ccea..71b58be1f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6479,15 +6479,16 @@ kernel void kernel_pool_2d_max_f32( const int cur_oh = idx % O_HW / OW; const int cur_ow = idx % O_HW % OW; - device const float* i_ptr = src0 + nc * I_HW; - device float* o_ptr = dst + nc * O_HW; + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; const int start_h = cur_oh * s1 - p1; - const int bh = MAX(0, start_h); + const int bh = MAX(0, start_h); const int eh = MIN(IH, start_h + k1); const int start_w = cur_ow * s0 - p0; - const int bw = MAX(0, start_w); + const int bw = MAX(0, start_w); const int ew = MIN(IW, start_w + k0); + float res = -INFINITY; for (int i = bh; i < eh; i += 1) { @@ -6495,23 +6496,24 @@ kernel void kernel_pool_2d_max_f32( res = MAX(res, i_ptr[i * IW + j]); } } + o_ptr[cur_oh * OW + cur_ow] = res; } kernel void kernel_pool_2d_avg_f32( - device const float* src0, - device float* dst, - constant int32_t& k0, - constant int32_t& k1, - constant int32_t& s0, - constant int32_t& s1, - constant int32_t& p0, - constant int32_t& p1, - constant int64_t& IH, - constant int64_t& IW, - constant int64_t& OH, - constant int64_t& OW, - constant int64_t& parallel_elements, + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, uint gid[[thread_position_in_grid]]) { if (gid >= parallel_elements) { @@ -6525,17 +6527,18 @@ kernel void kernel_pool_2d_avg_f32( const int cur_oh = idx % O_HW / OW; const int cur_ow = idx % O_HW % OW; - device const float* i_ptr = src0 + nc * I_HW; - device float* o_ptr = dst + nc * O_HW; + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; const int start_h = cur_oh * s1 - p1; - const int bh = MAX(0, start_h); + const int bh = MAX(0, start_h); const int eh = MIN(IH, start_h + k1); const int start_w = cur_ow * s0 - p0; - const int bw = MAX(0, start_w); + const int bw = MAX(0, start_w); const int ew = MIN(IW, start_w + k0); // const float scale = 1. / ((eh - bh) * (ew - bw)); const float scale = 1. / (k0 * k1); + float res = 0; for (int i = bh; i < eh; i += 1) { @@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32( res += cur * scale; } } + o_ptr[cur_oh * OW + cur_ow] = res; }