apply review

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
Junhee Yoo 2024-10-23 17:00:17 +09:00
parent bb9949b3f6
commit 746e79e9a5
2 changed files with 31 additions and 25 deletions

View file

@ -854,7 +854,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_POOL_1D: case GGML_OP_POOL_1D:
return false; return false;
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
return true;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
@ -2554,6 +2553,8 @@ static void ggml_metal_encode_node(
} break; } break;
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@ -3034,9 +3035,10 @@ static void ggml_metal_encode_node(
} break; } break;
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
const int32_t* opts = dst->op_params; const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0]; enum ggml_op_pool op = opts[0];
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;

View file

@ -6479,8 +6479,8 @@ kernel void kernel_pool_2d_max_f32(
const int cur_oh = idx % O_HW / OW; const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW; const int cur_ow = idx % O_HW % OW;
device const float* i_ptr = src0 + nc * I_HW; device const float * i_ptr = src0 + nc * I_HW;
device float* o_ptr = dst + nc * O_HW; device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1; const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h); const int bh = MAX(0, start_h);
@ -6488,6 +6488,7 @@ kernel void kernel_pool_2d_max_f32(
const int start_w = cur_ow * s0 - p0; const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w); const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0); const int ew = MIN(IW, start_w + k0);
float res = -INFINITY; float res = -INFINITY;
for (int i = bh; i < eh; i += 1) { for (int i = bh; i < eh; i += 1) {
@ -6495,23 +6496,24 @@ kernel void kernel_pool_2d_max_f32(
res = MAX(res, i_ptr[i * IW + j]); res = MAX(res, i_ptr[i * IW + j]);
} }
} }
o_ptr[cur_oh * OW + cur_ow] = res; o_ptr[cur_oh * OW + cur_ow] = res;
} }
kernel void kernel_pool_2d_avg_f32( kernel void kernel_pool_2d_avg_f32(
device const float* src0, device const float * src0,
device float* dst, device float * dst,
constant int32_t& k0, constant int32_t & k0,
constant int32_t& k1, constant int32_t & k1,
constant int32_t& s0, constant int32_t & s0,
constant int32_t& s1, constant int32_t & s1,
constant int32_t& p0, constant int32_t & p0,
constant int32_t& p1, constant int32_t & p1,
constant int64_t& IH, constant int64_t & IH,
constant int64_t& IW, constant int64_t & IW,
constant int64_t& OH, constant int64_t & OH,
constant int64_t& OW, constant int64_t & OW,
constant int64_t& parallel_elements, constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) { uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) { if (gid >= parallel_elements) {
@ -6525,8 +6527,8 @@ kernel void kernel_pool_2d_avg_f32(
const int cur_oh = idx % O_HW / OW; const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW; const int cur_ow = idx % O_HW % OW;
device const float* i_ptr = src0 + nc * I_HW; device const float * i_ptr = src0 + nc * I_HW;
device float* o_ptr = dst + nc * O_HW; device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1; const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h); const int bh = MAX(0, start_h);
@ -6536,6 +6538,7 @@ kernel void kernel_pool_2d_avg_f32(
const int ew = MIN(IW, start_w + k0); const int ew = MIN(IW, start_w + k0);
// const float scale = 1. / ((eh - bh) * (ew - bw)); // const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (k0 * k1); const float scale = 1. / (k0 * k1);
float res = 0; float res = 0;
for (int i = bh; i < eh; i += 1) { for (int i = bh; i < eh; i += 1) {
@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32(
res += cur * scale; res += cur * scale;
} }
} }
o_ptr[cur_oh * OW + cur_ow] = res; o_ptr[cur_oh * OW + cur_ow] = res;
} }