metal : normalize encoder:setComputePipelineState calls
ggml-ci
This commit is contained in:
parent
fef1dbf2eb
commit
77bb72cd8c
1 changed file with 170 additions and 101 deletions
271
ggml-metal.m
271
ggml-metal.m
|
@ -1051,7 +1051,9 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
const int64_t nb = ne00;
|
const int64_t nb = ne00;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1183,11 +1185,8 @@ bool ggml_metal_graph_compute(
|
||||||
// not sure how to avoid this
|
// not sure how to avoid this
|
||||||
// TODO: make a simpler cpy_bytes kernel
|
// TODO: make a simpler cpy_bytes kernel
|
||||||
|
|
||||||
// pipeline
|
|
||||||
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
||||||
|
|
||||||
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
@ -1208,6 +1207,8 @@ bool ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||||
|
|
||||||
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1255,13 +1256,16 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
int64_t n = ggml_nelements(dst);
|
int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
if (n % 4 == 0) {
|
if (n % 4 == 0) {
|
||||||
n /= 4;
|
n /= 4;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
@ -1272,7 +1276,9 @@ bool ggml_metal_graph_compute(
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
@ -1282,7 +1288,9 @@ bool ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
@ -1292,7 +1300,9 @@ bool ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
@ -1303,7 +1313,9 @@ bool ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
@ -1314,7 +1326,9 @@ bool ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
@ -1333,18 +1347,23 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
|
@ -1378,20 +1397,23 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
if (ne00%4 == 0) {
|
if (ne00%4 == 0) {
|
||||||
while (nth < ne00/4 && nth < 256) {
|
while (nth < ne00/4 && nth < 256) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
||||||
} else {
|
} else {
|
||||||
while (nth < ne00 && nth < 1024) {
|
while (nth < ne00 && nth < 1024) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float scale = ((float *) dst->op_params)[0];
|
const float scale = ((float *) dst->op_params)[0];
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
if (id_src1) {
|
if (id_src1) {
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
@ -1411,11 +1433,15 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
if (ne00%8 == 0) {
|
if (ne00%8 == 0) {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
|
@ -1475,21 +1501,26 @@ bool ggml_metal_graph_compute(
|
||||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||||
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
||||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32].pipeline; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline]; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline]; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline]; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline]; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline]; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline]; break;
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline]; break;
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline]; break;
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline]; break;
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline]; break;
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1513,12 +1544,14 @@ bool ggml_metal_graph_compute(
|
||||||
int nrows = 1;
|
int nrows = 1;
|
||||||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
@ -1527,16 +1560,16 @@ bool ggml_metal_graph_compute(
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (src1t == GGML_TYPE_F32) {
|
if (src1t == GGML_TYPE_F32) {
|
||||||
if (ne11 * ne12 < 4) {
|
if (ne11 * ne12 < 4) {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
||||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
||||||
nrows = ne11;
|
nrows = ne11;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
@ -1544,61 +1577,61 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
{
|
{
|
||||||
nth0 = 4; //1;
|
nth0 = 4; //1;
|
||||||
nth1 = 8; //32;
|
nth1 = 8; //32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -1611,6 +1644,7 @@ bool ggml_metal_graph_compute(
|
||||||
GGML_ASSERT(ne00 >= nth0*nth1);
|
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1712,21 +1746,26 @@ bool ggml_metal_graph_compute(
|
||||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
ne20 % 32 == 0 && ne20 >= 64 &&
|
ne20 % 32 == 0 && ne20 >= 64 &&
|
||||||
ne11 > ne11_mm_min) {
|
ne11 > ne11_mm_min) {
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src2->type) {
|
switch (src2->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32].pipeline; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline]; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline]; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline]; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline]; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline]; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline]; break;
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline]; break;
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline]; break;
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline]; break;
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline]; break;
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1766,79 +1805,81 @@ bool ggml_metal_graph_compute(
|
||||||
int nrows = 1;
|
int nrows = 1;
|
||||||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src2t) {
|
switch (src2t) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
nth0 = 32;
|
nth0 = 32;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
{
|
{
|
||||||
nth0 = 4; //1;
|
nth0 = 4; //1;
|
||||||
nth1 = 8; //32;
|
nth1 = 8; //32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
{
|
{
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline];
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -1853,6 +1894,7 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1915,23 +1957,26 @@ bool ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32].pipeline; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline]; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline]; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline]; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline]; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline]; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0].pipeline; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline]; break;
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K].pipeline; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline]; break;
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K].pipeline; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline]; break;
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K].pipeline; break;
|
||||||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline]; break;
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K].pipeline; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline]; break;
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K].pipeline; break;
|
||||||
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline]; break;
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -1959,7 +2004,9 @@ bool ggml_metal_graph_compute(
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
@ -1988,7 +2035,9 @@ bool ggml_metal_graph_compute(
|
||||||
// nth *= 2;
|
// nth *= 2;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
@ -2010,7 +2059,9 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
const int nth = MIN(256, ne00);
|
const int nth = MIN(256, ne00);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
@ -2037,7 +2088,9 @@ bool ggml_metal_graph_compute(
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
@ -2082,12 +2135,15 @@ bool ggml_metal_graph_compute(
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
@ -2150,12 +2206,15 @@ bool ggml_metal_graph_compute(
|
||||||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||||
|
@ -2209,7 +2268,9 @@ bool ggml_metal_graph_compute(
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
|
@ -2242,12 +2303,15 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (order) {
|
switch (order) {
|
||||||
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline]; break;
|
case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
||||||
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline]; break;
|
case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
@ -2261,7 +2325,9 @@ bool ggml_metal_graph_compute(
|
||||||
float slope;
|
float slope;
|
||||||
memcpy(&slope, dst->op_params, sizeof(float));
|
memcpy(&slope, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline];
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
||||||
|
@ -2278,33 +2344,36 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
||||||
|
|
||||||
switch (dstt) {
|
switch (dstt) {
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline]; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline]; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline]; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
||||||
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline]; break;
|
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||||
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline]; break;
|
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
switch (dstt) {
|
switch (dstt) {
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
|
||||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline]; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue