sync : ggml (SD ops, tests, kernels) (#4444)

* sync : ggml (SD ops, tests, kernels)

ggml-ci

* cuda : restore im2col

ggml-ci

* metal : fix accuracy of dequantization kernels

ggml-ci

* cuda : restore correct im2col

ggml-ci

* metal : try to fix moe test by reducing expert size

ggml-ci

* cuda : fix bin bcast when src1 and dst have different types

ggml-ci

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2023-12-13 21:54:54 +02:00 committed by GitHub
parent 70f806b821
commit 4d98d9a656
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1334 additions and 130 deletions

View file

@ -66,9 +66,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(div_row);
GGML_METAL_DECL_KERNEL(scale);
GGML_METAL_DECL_KERNEL(scale_4);
GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(tanh);
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(gelu_quick);
GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
@ -86,6 +88,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@ -145,8 +148,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(upscale_f32);
GGML_METAL_DECL_KERNEL(pad_f32);
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@ -334,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(div_row);
GGML_METAL_ADD_KERNEL(scale);
GGML_METAL_ADD_KERNEL(scale_4);
GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(tanh);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(gelu_quick);
GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
@ -354,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@ -415,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(upscale_f32);
GGML_METAL_ADD_KERNEL(pad_f32);
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@ -450,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(div_row);
GGML_METAL_DEL_KERNEL(scale);
GGML_METAL_DEL_KERNEL(scale_4);
GGML_METAL_DEL_KERNEL(silu);
GGML_METAL_DEL_KERNEL(tanh);
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(gelu_quick);
GGML_METAL_DEL_KERNEL(silu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
@ -470,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@ -531,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(upscale_f32);
GGML_METAL_DEL_KERNEL(pad_f32);
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@ -843,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
return true;
default:
return false;
@ -853,11 +873,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
@ -865,11 +885,15 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_SUM_ROWS:
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_NORM:
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
@ -902,8 +926,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
};
}
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
return op->ne[0] % 4 == 0;
return op->ne[3] == 1;
}
default:
return false;
@ -979,7 +1004,10 @@ void ggml_metal_graph_compute(
} break;
}
GGML_ASSERT(ggml_metal_supports_op(dst));
if (!ggml_metal_supports_op(dst)) {
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(!"unsupported op");
}
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
@ -1076,6 +1104,8 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
case GGML_OP_DIV:
{
const size_t offs = 0;
bool bcast_row = false;
int64_t nb = ne00;
@ -1134,7 +1164,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
if (bcast_row) {
const int64_t n = ggml_nelements(dst)/4;
@ -1146,6 +1177,86 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_ACC:
{
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(dstt == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
const size_t offs = ((int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
const int nth = MIN(1024, ne00);
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
[encoder setComputePipelineState:ctx->pipeline_add];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SCALE:
{
GGML_ASSERT(ggml_is_contiguous(src0));
@ -1169,16 +1280,15 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_TANH:
{
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setComputePipelineState:ctx->pipeline_tanh];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_RELU:
{
@ -1199,6 +1309,28 @@ void ggml_metal_graph_compute(
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_SILU:
{
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
@ -1837,6 +1969,38 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
//float eps;
//memcpy(&eps, dst->op_params, sizeof(float));
const float eps = 1e-6f; // TODO: temporarily hardcoded
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
int nth = 32; // SIMD width
//while (nth < ne00/4 && nth < 1024) {
// nth *= 2;
//}
[encoder setComputePipelineState:ctx->pipeline_group_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_NORM:
{
float eps;
@ -2006,6 +2170,65 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_UPSCALE:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
const int sf = dst->op_params[0];
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_PAD:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
@ -2027,6 +2250,22 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
} break;
case GGML_OP_LEAKY_RELU:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
float slope;
memcpy(&slope, dst->op_params, sizeof(float));
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT: