apply suggestions
Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
e81462dda1
commit
0084847991
2 changed files with 19 additions and 19 deletions
|
@ -274,8 +274,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
|
@ -722,8 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true);
|
||||
}
|
||||
|
||||
[metal_library release];
|
||||
|
@ -3044,9 +3044,9 @@ static void ggml_metal_encode_node(
|
|||
case GGML_TYPE_F32: {
|
||||
switch(op) {
|
||||
case GGML_OP_POOL_AVG:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
||||
case GGML_OP_POOL_MAX:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
} break;
|
||||
|
|
|
@ -6453,19 +6453,19 @@ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t
|
|||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
|
||||
kernel void kernel_max_pool_2d_f32(
|
||||
device const float* src0,
|
||||
device float* dst,
|
||||
constant int32_t& k0,
|
||||
constant int32_t& k1,
|
||||
constant int32_t& s0,
|
||||
constant int32_t& s1,
|
||||
constant int32_t& p0,
|
||||
constant int32_t& p1,
|
||||
constant int64_t& IH,
|
||||
constant int64_t& IW,
|
||||
constant int64_t& OH,
|
||||
constant int64_t& OW,
|
||||
constant int64_t& parallel_elements,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue