diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 172a0f925..6d854bb07 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -272,6 +272,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, + GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -716,6 +718,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true); } [metal_library release]; @@ -844,8 +848,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: return false; + case GGML_OP_POOL_2D: + return true; case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: @@ -3001,6 +3006,63 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); + + const int32_t* opts = dst->op_params; + enum ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case GGML_TYPE_F32: { + switch(op) { + case GGML_OP_POOL_AVG: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break; + case GGML_OP_POOL_MAX: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + } + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; + [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; + [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; + [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; + [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b2000323..3cbbeb9c2 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6372,3 +6372,98 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_max_pool_2d_f32( + device const float* src0, + device float* dst, + constant int32_t& k0, + constant int32_t& k1, + constant int32_t& s0, + constant int32_t& s1, + constant int32_t& p0, + constant int32_t& p1, + constant int64_t& IH, + constant int64_t& IW, + constant int64_t& OH, + constant int64_t& OW, + constant int64_t& parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float* i_ptr = src0 + nc * I_HW; + device float* o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } + o_ptr[cur_oh * OW + cur_ow] = res; +} + +kernel void kernel_avg_pool_2d_f32( + device const float* src0, + device float* dst, + constant int32_t& k0, + constant int32_t& k1, + constant int32_t& s0, + constant int32_t& s1, + constant int32_t& p0, + constant int32_t& p1, + constant int64_t& IH, + constant int64_t& IW, + constant int64_t& OH, + constant int64_t& OW, + constant int64_t& parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float* i_ptr = src0 + nc * I_HW; + device float* o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } + o_ptr[cur_oh * OW + cur_ow] = res; +}