metal : normalize encoder:setComputePipelineState calls

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-01-06 16:33:16 +02:00
parent fef1dbf2eb
commit 77bb72cd8c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1051,7 +1051,9 @@ bool ggml_metal_graph_compute(
{
const int64_t nb = ne00;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1183,11 +1185,8 @@ bool ggml_metal_graph_compute(
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
// pipeline
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1208,6 +1207,8 @@ bool ggml_metal_graph_compute(
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
@ -1255,13 +1256,16 @@ bool ggml_metal_graph_compute(
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n % 4 == 0) {
n /= 4;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
@ -1272,7 +1276,9 @@ bool ggml_metal_graph_compute(
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_TANH:
{
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1282,7 +1288,9 @@ bool ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_RELU:
{
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1292,7 +1300,9 @@ bool ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_GELU:
{
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1303,7 +1313,9 @@ bool ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1314,7 +1326,9 @@ bool ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_SILU:
{
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1333,18 +1347,23 @@ bool ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@ -1378,20 +1397,23 @@ bool ggml_metal_graph_compute(
{
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
} else {
while (nth < ne00 && nth < 1024) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
const float scale = ((float *) dst->op_params)[0];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -1411,11 +1433,15 @@ bool ggml_metal_graph_compute(
{
const int n_past = ((int32_t *)(dst->op_params))[0];
id<MTLComputePipelineState> pipeline = nil;
if (ne00%8 == 0) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@ -1475,21 +1501,26 @@ bool ggml_metal_graph_compute(
ne00 % 32 == 0 && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline]; break;
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline]; break;
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1513,12 +1544,14 @@ bool ggml_metal_graph_compute(
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
id<MTLComputePipelineState> pipeline = nil;
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
nrows = 4;
} break;
case GGML_TYPE_F16:
@ -1527,16 +1560,16 @@ bool ggml_metal_graph_compute(
nth1 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
nrows = 4;
}
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
nrows = 4;
}
} break;
@ -1544,61 +1577,61 @@ bool ggml_metal_graph_compute(
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
} break;
default:
{
@ -1611,6 +1644,7 @@ bool ggml_metal_graph_compute(
GGML_ASSERT(ne00 >= nth0*nth1);
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1712,21 +1746,26 @@ bool ggml_metal_graph_compute(
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) {
id<MTLComputePipelineState> pipeline = nil;
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline]; break;
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline]; break;
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1766,79 +1805,81 @@ bool ggml_metal_graph_compute(
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
id<MTLComputePipelineState> pipeline = nil;
// use custom matrix x vector kernel
switch (src2t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
} break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
} break;
default:
{
@ -1853,6 +1894,7 @@ bool ggml_metal_graph_compute(
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1915,23 +1957,26 @@ bool ggml_metal_graph_compute(
} break;
case GGML_OP_GET_ROWS:
{
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline]; break;
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline]; break;
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline]; break;
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1959,7 +2004,9 @@ bool ggml_metal_graph_compute(
nth *= 2;
}
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@ -1988,7 +2035,9 @@ bool ggml_metal_graph_compute(
// nth *= 2;
//}
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@ -2010,7 +2059,9 @@ bool ggml_metal_graph_compute(
const int nth = MIN(256, ne00);
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@ -2037,7 +2088,9 @@ bool ggml_metal_graph_compute(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@ -2082,12 +2135,15 @@ bool ggml_metal_graph_compute(
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline]; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
default: GGML_ASSERT(false);
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -2150,12 +2206,15 @@ bool ggml_metal_graph_compute(
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
@ -2209,7 +2268,9 @@ bool ggml_metal_graph_compute(
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@ -2242,12 +2303,15 @@ bool ggml_metal_graph_compute(
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
id<MTLComputePipelineState> pipeline = nil;
switch (order) {
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline]; break;
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline]; break;
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
default: GGML_ASSERT(false);
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@ -2261,7 +2325,9 @@ bool ggml_metal_graph_compute(
float slope;
memcpy(&slope, dst->op_params, sizeof(float));
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
@ -2278,33 +2344,36 @@ bool ggml_metal_graph_compute(
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline]; break;
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline]; break;
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
case GGML_TYPE_F16:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline]; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];