From 77bb72cd8cf0cc98ba0a3fbc3c7005a26279dc7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 6 Jan 2024 16:33:16 +0200 Subject: [PATCH] metal : normalize encoder:setComputePipelineStatus calls ggml-ci --- ggml-metal.m | 271 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 170 insertions(+), 101 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 9ac640061..d7403fc02 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1051,7 +1051,9 @@ bool ggml_metal_graph_compute( { const int64_t nb = ne00; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1183,11 +1185,8 @@ bool ggml_metal_graph_compute( // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel - // pipeline const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1208,6 +1207,8 @@ bool ggml_metal_graph_compute( [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } @@ -1255,13 +1256,16 @@ bool ggml_metal_graph_compute( int64_t n = ggml_nelements(dst); + id pipeline = nil; + if (n % 4 == 0) { n /= 4; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; } + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; @@ -1272,7 +1276,9 @@ bool ggml_metal_graph_compute( switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_TANH: { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1282,7 +1288,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_RELU: { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1292,7 +1300,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1303,7 +1313,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU_QUICK: { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1314,7 +1326,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_SILU: { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1333,18 +1347,23 @@ bool ggml_metal_graph_compute( { GGML_ASSERT(ggml_is_contiguous(src0)); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SUM_ROWS: { GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; @@ -1378,20 +1397,23 @@ bool ggml_metal_graph_compute( { int nth = 32; // SIMD width + id pipeline = nil; + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; } const float scale = ((float *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1411,11 +1433,15 @@ bool ggml_metal_graph_compute( { const int n_past = ((int32_t *)(dst->op_params))[0]; + id pipeline = nil; + if (ne00%8 == 0) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; @@ -1475,21 +1501,26 @@ bool ggml_metal_graph_compute( ne00 % 32 == 0 && ne00 >= 64 && (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline]; break; - case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline]; break; - case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1513,12 +1544,14 @@ bool ggml_metal_graph_compute( int nrows = 1; //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + id pipeline = nil; + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; nrows = 4; } break; case GGML_TYPE_F16: @@ -1527,16 +1560,16 @@ bool ggml_metal_graph_compute( nth1 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; nrows = ne11; } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; nrows = 4; } } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; nrows = 4; } } break; @@ -1544,61 +1577,61 @@ bool ggml_metal_graph_compute( { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { nth0 = 4; //1; nth1 = 8; //32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; default: { @@ -1611,6 +1644,7 @@ bool ggml_metal_graph_compute( GGML_ASSERT(ne00 >= nth0*nth1); } + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1712,21 +1746,26 @@ bool ggml_metal_graph_compute( if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && ne20 % 32 == 0 && ne20 >= 64 && ne11 > ne11_mm_min) { + + id pipeline = nil; + switch (src2->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline]; break; - case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline]; break; - case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1766,79 +1805,81 @@ bool ggml_metal_graph_compute( int nrows = 1; //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + id pipeline = nil; + // use custom matrix x vector kernel switch (src2t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == GGML_TYPE_F32); nth0 = 32; nth1 = 1; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { nth0 = 4; //1; nth1 = 8; //32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; default: { @@ -1853,6 +1894,7 @@ bool ggml_metal_graph_compute( const int64_t _ne1 = 1; // kernels needs a reference in constant memory + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1915,23 +1957,26 @@ bool ggml_metal_graph_compute( } break; case GGML_OP_GET_ROWS: { + id pipeline = nil; + switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline]; break; - case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline]; break; - case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline]; break; - case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1959,7 +2004,9 @@ bool ggml_metal_graph_compute( nth *= 2; } - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -1988,7 +2035,9 @@ bool ggml_metal_graph_compute( // nth *= 2; //} - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -2010,7 +2059,9 @@ bool ggml_metal_graph_compute( const int nth = MIN(256, ne00); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -2037,7 +2088,9 @@ bool ggml_metal_graph_compute( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -2082,12 +2135,15 @@ bool ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + id pipeline = nil; + switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; default: GGML_ASSERT(false); }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -2150,12 +2206,15 @@ bool ggml_metal_graph_compute( const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + id pipeline = nil; + switch (src0->type) { case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; default: GGML_ASSERT(false); }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; @@ -2209,7 +2268,9 @@ bool ggml_metal_graph_compute( { GGML_ASSERT(src0->type == GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; @@ -2242,12 +2303,15 @@ bool ggml_metal_graph_compute( enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + id pipeline = nil; + switch (order) { - case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline]; break; - case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline]; break; + case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; default: GGML_ASSERT(false); }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -2261,7 +2325,9 @@ bool ggml_metal_graph_compute( float slope; memcpy(&slope, dst->op_params, sizeof(float)); - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; @@ -2278,33 +2344,36 @@ bool ggml_metal_graph_compute( int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + id pipeline = nil; + switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline]; break; - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline]; break; - //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline]; break; - //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; default: GGML_ASSERT(false && "not implemented"); }; } break; case GGML_TYPE_F16: { switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; break; - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; default: GGML_ASSERT(false && "not implemented"); }; } break; default: GGML_ASSERT(false && "not implemented"); } + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];