add pool_2d
Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
f010b77a37
commit
b4d3c16493
2 changed files with 158 additions and 1 deletion
|
@ -272,6 +272,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
|
@ -716,6 +718,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true);
|
||||
}
|
||||
|
||||
[metal_library release];
|
||||
|
@ -844,8 +848,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
return false;
|
||||
case GGML_OP_POOL_2D:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
|
@ -3001,6 +3006,63 @@ static void ggml_metal_encode_node(
|
|||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
||||
|
||||
const int32_t* opts = dst->op_params;
|
||||
enum ggml_op_pool op = opts[0];
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32: {
|
||||
switch(op) {
|
||||
case GGML_OP_POOL_AVG:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break;
|
||||
case GGML_OP_POOL_MAX:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
} break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
const int32_t k0 = opts[1];
|
||||
const int32_t k1 = opts[2];
|
||||
const int32_t s0 = opts[3];
|
||||
const int32_t s1 = opts[4];
|
||||
const int32_t p0 = opts[5];
|
||||
const int32_t p1 = opts[6];
|
||||
|
||||
const int64_t IH = src0->ne[1];
|
||||
const int64_t IW = src0->ne[0];
|
||||
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t OC = dst->ne[2];
|
||||
const int64_t OH = dst->ne[1];
|
||||
const int64_t OW = dst->ne[0];
|
||||
|
||||
const int64_t parallel_elements = N * OC * OH * OW;
|
||||
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
||||
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
||||
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
||||
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
||||
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
||||
[encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
|
|
@ -6372,3 +6372,98 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
|
|||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
|
||||
kernel void kernel_max_pool_2d_f32(
|
||||
device const float* src0,
|
||||
device float* dst,
|
||||
constant int32_t& k0,
|
||||
constant int32_t& k1,
|
||||
constant int32_t& s0,
|
||||
constant int32_t& s1,
|
||||
constant int32_t& p0,
|
||||
constant int32_t& p1,
|
||||
constant int64_t& IH,
|
||||
constant int64_t& IW,
|
||||
constant int64_t& OH,
|
||||
constant int64_t& OW,
|
||||
constant int64_t& parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float* i_ptr = src0 + nc * I_HW;
|
||||
device float* o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * IW + j]);
|
||||
}
|
||||
}
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_avg_pool_2d_f32(
|
||||
device const float* src0,
|
||||
device float* dst,
|
||||
constant int32_t& k0,
|
||||
constant int32_t& k1,
|
||||
constant int32_t& s0,
|
||||
constant int32_t& s1,
|
||||
constant int32_t& p0,
|
||||
constant int32_t& p1,
|
||||
constant int64_t& IH,
|
||||
constant int64_t& IW,
|
||||
constant int64_t& OH,
|
||||
constant int64_t& OW,
|
||||
constant int64_t& parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float* i_ptr = src0 + nc * I_HW;
|
||||
device float* o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (k0 * k1);
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue