Releasing MTLFunction references after Metal pipeline construction

This commit is contained in:
Paul Tsochantaris 2024-01-27 22:43:01 +00:00
parent a1d6df129b
commit 266349ae89

View file

@ -24,13 +24,6 @@
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
// Capacity of the per-context kernel table (the `kernels` array in ggml_metal_context).
#define GGML_METAL_MAX_KERNELS 256
// One slot of the per-context kernel table: a Metal function together with
// the compute pipeline state that was created from it.
struct ggml_metal_kernel {
// Function handle obtained from the Metal library via newFunctionWithName:.
id<MTLFunction> function;
// Pipeline state created from `function`; this is the object the
// graph-compute code binds on the encoder with setComputePipelineState:.
id<MTLComputePipelineState> pipeline;
};
enum ggml_metal_kernel_type { enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_ROW, GGML_METAL_KERNEL_TYPE_ADD_ROW,
@ -163,7 +156,7 @@ struct ggml_metal_context {
dispatch_queue_t d_queue; dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS]; id<MTLComputePipelineState> pipelines[GGML_METAL_KERNEL_TYPE_COUNT];
bool support_simdgroup_reduction; bool support_simdgroup_reduction;
bool support_simdgroup_mm; bool support_simdgroup_mm;
@ -367,9 +360,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
{ {
NSError * error = nil; NSError * error = nil;
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
ctx->kernels[i].function = nil; ctx->pipelines[i] = nil;
ctx->kernels[i].pipeline = nil;
} }
/* /*
@ -379,9 +371,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
*/ */
#define GGML_METAL_ADD_KERNEL(e, name, supported) \ #define GGML_METAL_ADD_KERNEL(e, name, supported) \
if (supported) { \ if (supported) { \
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ id<MTLFunction> metal_function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->pipelines[e] = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ [metal_function release]; \
if (error) { \ if (error) { \
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \ return NULL; \
@ -518,14 +510,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
static void ggml_metal_free(struct ggml_metal_context * ctx) { static void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
if (ctx->kernels[i].pipeline) { [ctx->pipelines[i] release];
[ctx->kernels[i].pipeline release];
}
if (ctx->kernels[i].function) {
[ctx->kernels[i].function release];
}
} }
[ctx->library release]; [ctx->library release];
@ -799,7 +785,7 @@ static bool ggml_metal_graph_compute(
{ {
const int64_t nb = ne00; const int64_t nb = ne00;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CONCAT];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -855,18 +841,18 @@ static bool ggml_metal_graph_compute(
nb = ne00 / 4; nb = ne00 / 4;
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_ADD: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD_ROW]; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; case GGML_OP_MUL: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_ROW]; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; case GGML_OP_DIV: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIV_ROW]; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
bcast_row = true; bcast_row = true;
} else { } else {
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; case GGML_OP_ADD: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD]; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; case GGML_OP_MUL: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL]; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; case GGML_OP_DIV: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIV]; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
} }
@ -933,7 +919,7 @@ static bool ggml_metal_graph_compute(
// not sure how to avoid this // not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel // TODO: make a simpler cpy_bytes kernel
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F32];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -960,7 +946,7 @@ static bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} }
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ADD];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1008,9 +994,9 @@ static bool ggml_metal_graph_compute(
if (n % 4 == 0) { if (n % 4 == 0) {
n /= 4; n /= 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SCALE_4];
} else { } else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SCALE];
} }
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@ -1024,7 +1010,7 @@ static bool ggml_metal_graph_compute(
switch (ggml_get_unary_op(gf->nodes[i])) { switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_TANH:
{ {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_TANH];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1036,7 +1022,7 @@ static bool ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_RELU];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1048,7 +1034,7 @@ static bool ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GELU];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1061,7 +1047,7 @@ static bool ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_QUICK:
{ {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GELU_QUICK];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1074,7 +1060,7 @@ static bool ggml_metal_graph_compute(
} break; } break;
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
{ {
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SILU];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1095,7 +1081,7 @@ static bool ggml_metal_graph_compute(
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SQR];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1109,7 +1095,7 @@ static bool ggml_metal_graph_compute(
{ {
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SUM_ROWS];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1151,12 +1137,12 @@ static bool ggml_metal_graph_compute(
while (nth < ne00/4 && nth < 256) { while (nth < ne00/4 && nth < 256) {
nth *= 2; nth *= 2;
} }
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4];
} else { } else {
while (nth < ne00 && nth < 1024) { while (nth < ne00 && nth < 1024) {
nth *= 2; nth *= 2;
} }
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_SOFT_MAX];
} }
const float scale = ((float *) dst->op_params)[0]; const float scale = ((float *) dst->op_params)[0];
@ -1184,9 +1170,9 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
if (ne00%8 == 0) { if (ne00%8 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8];
} else { } else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF];
} }
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@ -1253,20 +1239,20 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ]; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ]; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ]; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ]; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ]; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ]; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ]; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ]; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ]; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ]; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ]; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32]; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
} }
@ -1301,7 +1287,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32];
nrows = 4; nrows = 4;
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
@ -1310,16 +1296,16 @@ static bool ggml_metal_graph_compute(
nth1 = 1; nth1 = 1;
if (src1t == GGML_TYPE_F32) { if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) { if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4];
nrows = ne11; nrows = ne11;
} else { } else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32];
nrows = 4; nrows = 4;
} }
} else { } else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16];
nrows = 4; nrows = 4;
} }
} break; } break;
@ -1327,73 +1313,73 @@ static bool ggml_metal_graph_compute(
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32];
} break; } break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32];
} break; } break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32];
} break; } break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32];
} break; } break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32];
} break; } break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32];
} break; } break;
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32];
} break; } break;
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
{ {
nth0 = 4; //1; nth0 = 4; //1;
nth1 = 8; //32; nth1 = 8; //32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32];
} break; } break;
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32];
} break; } break;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32];
} break; } break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
{ {
nth0 = 4; nth0 = 4;
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32];
} break; } break;
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
{ {
nth0 = 4; nth0 = 4;
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32];
} break; } break;
default: default:
{ {
@ -1517,20 +1503,20 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src2->type) { switch (src2->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ]; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ]; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ]; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ]; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ]; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ]; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ]; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ]; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ]; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ]; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ]; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32]; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
} }
@ -1581,86 +1567,86 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32];
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
nth0 = 32; nth0 = 32;
nth1 = 1; nth1 = 1;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32];
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32];
} break; } break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32];
} break; } break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32];
} break; } break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32];
} break; } break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
{ {
nth0 = 8; nth0 = 8;
nth1 = 8; nth1 = 8;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32];
} break; } break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32];
} break; } break;
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32];
} break; } break;
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
{ {
nth0 = 4; //1; nth0 = 4; //1;
nth1 = 8; //32; nth1 = 8; //32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32];
} break; } break;
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32];
} break; } break;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
{ {
nth0 = 2; nth0 = 2;
nth1 = 32; nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32];
} break; } break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
{ {
nth0 = 4; nth0 = 4;
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32];
} break; } break;
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
{ {
nth0 = 4; nth0 = 4;
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32];
} break; } break;
default: default:
{ {
@ -1746,21 +1732,21 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ]; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ]; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ]; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ]; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ]; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ]; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ]; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ]; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ]; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; case GGML_TYPE_Q5_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ]; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; case GGML_TYPE_Q6_K: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ]; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS]; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ]; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ]; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
} }
@ -1792,7 +1778,7 @@ static bool ggml_metal_graph_compute(
nth *= 2; nth *= 2;
} }
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_RMS_NORM];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1823,7 +1809,7 @@ static bool ggml_metal_graph_compute(
// nth *= 2; // nth *= 2;
//} //}
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_GROUP_NORM];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1847,7 +1833,7 @@ static bool ggml_metal_graph_compute(
const int nth = MIN(256, ne00); const int nth = MIN(256, ne00);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_NORM];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1876,7 +1862,7 @@ static bool ggml_metal_graph_compute(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ALIBI_F32];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1926,8 +1912,8 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ROPE_F32]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ROPE_F16]; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
}; };
@ -1998,7 +1984,7 @@ static bool ggml_metal_graph_compute(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_IM2COL_F16]; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
}; };
@ -2025,7 +2011,7 @@ static bool ggml_metal_graph_compute(
const int sf = dst->op_params[0]; const int sf = dst->op_params[0];
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_UPSCALE_F32];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -2056,7 +2042,7 @@ static bool ggml_metal_graph_compute(
{ {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_PAD_F32];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -2094,8 +2080,8 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (order) { switch (order) {
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; case GGML_SORT_ASC: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC]; break;
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; case GGML_SORT_DESC: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC]; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
}; };
@ -2113,7 +2099,7 @@ static bool ggml_metal_graph_compute(
float slope; float slope;
memcpy(&slope, dst->op_params, sizeof(float)); memcpy(&slope, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -2140,21 +2126,21 @@ static bool ggml_metal_graph_compute(
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F16]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_F32]; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0]; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0]; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1]; break;
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; //case GGML_TYPE_Q5_0: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0]; break;
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; //case GGML_TYPE_Q5_1: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1]; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
}; };
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F16_F16]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->pipelines[GGML_METAL_KERNEL_TYPE_CPY_F16_F32]; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
}; };
} break; } break;