diff --git a/ggml-metal.m b/ggml-metal.m index af5a234b6..f05d9d331 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -24,6 +24,10 @@ #define UNUSED(x) (void)(x) +struct ggml_metal_kernel { + id pipeline; +}; + enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_ROW, @@ -152,11 +156,10 @@ struct ggml_metal_context { id device; id queue; - id library; - + dispatch_queue_t d_queue; - id pipelines[GGML_METAL_KERNEL_TYPE_COUNT]; + struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -239,6 +242,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->queue = [ctx->device newCommandQueue]; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + id metal_library; + // load library { NSBundle * bundle = nil; @@ -253,7 +258,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + metal_library = [ctx->device newLibraryWithURL:libURL error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -295,7 +300,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //[options setFastMathEnabled:false]; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + metal_library = [ctx->device newLibraryWithSource:src options:options error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -361,7 +366,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { NSError * error = nil; for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - ctx->pipelines[i] = nil; + ctx->kernels[i].pipeline = nil; } /* @@ -371,11 +376,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { */ #define GGML_METAL_ADD_KERNEL(e, name, supported) \ if (supported) { \ - id metal_function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipelines[e] = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ + struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ + id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ [metal_function release]; \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + [metal_library release]; \ return NULL; \ } \ } else { \ @@ -504,6 +511,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } + [metal_library release]; return ctx; } @@ -511,10 +519,9 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - [ctx->pipelines[i] release]; + [ctx->kernels[i].pipeline release]; } - [ctx->library release]; [ctx->queue release]; [ctx->device release]; @@ -785,7 +792,7 @@ static bool ggml_metal_graph_compute( { const int64_t nb = ne00; - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CONCAT]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -841,18 +848,18 @@ static bool ggml_metal_graph_compute( nb = ne00 / 4; switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD_ROW]; break; - case GGML_OP_MUL: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_ROW]; break; - case GGML_OP_DIV: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIV_ROW]; break; + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; default: GGML_ASSERT(false); } bcast_row = true; } else { switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD]; break; - case GGML_OP_MUL: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL]; break; - case GGML_OP_DIV: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIV]; break; + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; default: GGML_ASSERT(false); } } @@ -919,7 +926,7 @@ static bool ggml_metal_graph_compute( // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel - const id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F32]; + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -946,7 +953,7 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } - const id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD]; + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -994,9 +1001,9 @@ static bool ggml_metal_graph_compute( if (n % 4 == 0) { n /= 4; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SCALE_4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; } else { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SCALE]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; } [encoder setComputePipelineState:pipeline]; @@ -1010,7 +1017,7 @@ static bool ggml_metal_graph_compute( switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_TANH: { - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_TANH]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1022,7 +1029,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_RELU: { - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_RELU]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1034,7 +1041,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GELU]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1047,7 +1054,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU_QUICK: { - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GELU_QUICK]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1060,7 +1067,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_SILU: { - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SILU]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1081,7 +1088,7 @@ static bool ggml_metal_graph_compute( { GGML_ASSERT(ggml_is_contiguous(src0)); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SQR]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1095,7 +1102,7 @@ static bool ggml_metal_graph_compute( { GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SUM_ROWS]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1137,12 +1144,12 @@ static bool ggml_metal_graph_compute( while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SOFT_MAX]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; } const float scale = ((float *) dst->op_params)[0]; @@ -1170,9 +1177,9 @@ static bool ggml_metal_graph_compute( id pipeline = nil; if (ne00%8 == 0) { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; } else { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } [encoder setComputePipelineState:pipeline]; @@ -1239,20 +1246,20 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ]; break; - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ]; break; - case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ]; break; - case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ]; break; - case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ]; break; - case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ]; break; - case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ]; break; - case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ]; break; - case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ]; break; - case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ]; break; - case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ]; break; - case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ]; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32]; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } @@ -1287,7 +1294,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; nrows = 4; } break; case GGML_TYPE_F16: @@ -1296,16 +1303,16 @@ static bool ggml_metal_graph_compute( nth1 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; nrows = ne11; } else { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; nrows = 4; } } else { - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; nrows = 4; } } break; @@ -1313,73 +1320,73 @@ static bool ggml_metal_graph_compute( { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { nth0 = 4; //1; nth1 = 8; //32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { nth0 = 4; nth1 = 16; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { nth0 = 4; nth1 = 16; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; } break; default: { @@ -1503,20 +1510,20 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (src2->type) { - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ]; break; - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ]; break; - case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ]; break; - case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ]; break; - case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ]; break; - case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ]; break; - case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ]; break; - case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ]; break; - case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ]; break; - case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ]; break; - case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ]; break; - case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ]; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32]; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } @@ -1567,86 +1574,86 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == GGML_TYPE_F32); nth0 = 32; nth1 = 1; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { nth0 = 8; nth1 = 8; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { nth0 = 4; //1; nth1 = 8; //32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { nth0 = 2; nth1 = 32; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { nth0 = 4; nth1 = 16; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { nth0 = 4; nth1 = 16; - pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; } break; default: { @@ -1732,21 +1739,21 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ]; break; - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ]; break; - case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ]; break; - case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ]; break; - case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ]; break; - case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ]; break; - case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ]; break; - case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ]; break; - case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ]; break; - case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ]; break; - case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ]; break; - case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ]; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS]; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ]; break; - case GGML_TYPE_I32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } @@ -1778,7 +1785,7 @@ static bool ggml_metal_graph_compute( nth *= 2; } - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_RMS_NORM]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1809,7 +1816,7 @@ static bool ggml_metal_graph_compute( // nth *= 2; //} - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GROUP_NORM]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1833,7 +1840,7 @@ static bool ggml_metal_graph_compute( const int nth = MIN(256, ne00); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_NORM]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1862,7 +1869,7 @@ static bool ggml_metal_graph_compute( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ALIBI_F32]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1912,8 +1919,8 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ROPE_F32]; break; - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ROPE_F16]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; default: GGML_ASSERT(false); }; @@ -1984,7 +1991,7 @@ static bool ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_IM2COL_F16]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; default: GGML_ASSERT(false); }; @@ -2011,7 +2018,7 @@ static bool ggml_metal_graph_compute( const int sf = dst->op_params[0]; - const id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_UPSCALE_F32]; + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -2042,7 +2049,7 @@ static bool ggml_metal_graph_compute( { GGML_ASSERT(src0->type == GGML_TYPE_F32); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_PAD_F32]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -2080,8 +2087,8 @@ static bool ggml_metal_graph_compute( id pipeline = nil; switch (order) { - case GGML_SORT_ASC: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC]; break; - case GGML_SORT_DESC: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC]; break; + case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; default: GGML_ASSERT(false); }; @@ -2099,7 +2106,7 @@ static bool ggml_metal_graph_compute( float slope; memcpy(&slope, dst->op_params, sizeof(float)); - id pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -2126,21 +2133,21 @@ static bool ggml_metal_graph_compute( GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F16]; break; - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F32]; break; - case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0]; break; - case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0]; break; - case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1]; break; - //case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0]; break; - //case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; default: GGML_ASSERT(false && "not implemented"); }; } break; case GGML_TYPE_F16: { switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F16_F16]; break; - case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F16_F32]; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; default: GGML_ASSERT(false && "not implemented"); }; } break;