apply review
Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
bb9949b3f6
commit
746e79e9a5
2 changed files with 31 additions and 25 deletions
|
@ -854,7 +854,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_POOL_2D:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
|
@ -2554,6 +2553,8 @@ static void ggml_metal_encode_node(
|
|||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
@ -2620,7 +2621,7 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||
|
||||
if (is_gt_mttpt) {
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
|
@ -3034,9 +3035,10 @@ static void ggml_metal_encode_node(
|
|||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
||||
|
||||
const int32_t* opts = dst->op_params;
|
||||
const int32_t * opts = dst->op_params;
|
||||
enum ggml_op_pool op = opts[0];
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
@ -3063,7 +3065,7 @@ static void ggml_metal_encode_node(
|
|||
const int64_t IH = src0->ne[1];
|
||||
const int64_t IW = src0->ne[0];
|
||||
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t OC = dst->ne[2];
|
||||
const int64_t OH = dst->ne[1];
|
||||
const int64_t OW = dst->ne[0];
|
||||
|
|
|
@ -6479,15 +6479,16 @@ kernel void kernel_pool_2d_max_f32(
|
|||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float* i_ptr = src0 + nc * I_HW;
|
||||
device float* o_ptr = dst + nc * O_HW;
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
|
@ -6495,23 +6496,24 @@ kernel void kernel_pool_2d_max_f32(
|
|||
res = MAX(res, i_ptr[i * IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
device const float* src0,
|
||||
device float* dst,
|
||||
constant int32_t& k0,
|
||||
constant int32_t& k1,
|
||||
constant int32_t& s0,
|
||||
constant int32_t& s1,
|
||||
constant int32_t& p0,
|
||||
constant int32_t& p1,
|
||||
constant int64_t& IH,
|
||||
constant int64_t& IW,
|
||||
constant int64_t& OH,
|
||||
constant int64_t& OW,
|
||||
constant int64_t& parallel_elements,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
|
@ -6525,17 +6527,18 @@ kernel void kernel_pool_2d_avg_f32(
|
|||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float* i_ptr = src0 + nc * I_HW;
|
||||
device float* o_ptr = dst + nc * O_HW;
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (k0 * k1);
|
||||
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
|
@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32(
|
|||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue