From 008484799146274784a5df692fd7b7508805be83 Mon Sep 17 00:00:00 2001 From: Junhee Yoo Date: Wed, 23 Oct 2024 11:18:39 +0900 Subject: [PATCH] apply suggestions Signed-off-by: Junhee Yoo --- ggml/src/ggml-metal.m | 12 ++++++------ ggml/src/ggml-metal.metal | 26 +++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 1ee6b0295..cc6f00475 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -274,8 +274,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, - GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, - GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -722,8 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true); } [metal_library release]; @@ -3044,9 +3044,9 @@ static void ggml_metal_encode_node( case GGML_TYPE_F32: { switch(op) { case GGML_OP_POOL_AVG: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; case GGML_OP_POOL_MAX: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } } break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index da609234d..8b6eba9ba 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6453,19 +6453,19 @@ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; kernel void kernel_max_pool_2d_f32( - device const float* src0, - device float* dst, - constant int32_t& k0, - constant int32_t& k1, - constant int32_t& s0, - constant int32_t& s1, - constant int32_t& p0, - constant int32_t& p1, - constant int64_t& IH, - constant int64_t& IW, - constant int64_t& OH, - constant int64_t& OW, - constant int64_t& parallel_elements, + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, uint gid[[thread_position_in_grid]]) { if (gid >= parallel_elements) {