apply suggestions

Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
Junhee Yoo 2024-10-23 11:18:39 +09:00
parent e81462dda1
commit 0084847991
2 changed files with 19 additions and 19 deletions

View file

@ -274,8 +274,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_COUNT
};
@ -722,8 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, avg_pool_2d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, max_pool_2d_f32, true);
}
[metal_library release];
@ -3044,9 +3044,9 @@ static void ggml_metal_encode_node(
case GGML_TYPE_F32: {
switch(op) {
case GGML_OP_POOL_AVG:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
case GGML_OP_POOL_MAX:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
} break;

View file

@ -6453,19 +6453,19 @@ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
kernel void kernel_max_pool_2d_f32(
device const float* src0,
device float* dst,
constant int32_t& k0,
constant int32_t& k1,
constant int32_t& s0,
constant int32_t& s1,
constant int32_t& p0,
constant int32_t& p1,
constant int64_t& IH,
constant int64_t& IW,
constant int64_t& OH,
constant int64_t& OW,
constant int64_t& parallel_elements,
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {