From fa5a989104d210a33f3a7f5dbcf3c82747d8be87 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 8 Sep 2023 12:06:23 +0200 Subject: [PATCH] metal: faster diagonal infinity Although, to me it looks like one should simply fuse scale + diagnonal infinity + soft_max on the KQtensor. --- ggml-metal.m | 8 +++++--- ggml-metal.metal | 29 +++++++++++++++++++---------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index de2b1dbed..81edc5046 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -815,15 +815,17 @@ void ggml_metal_graph_compute( case GGML_OP_DIAG_MASK_INF: { const int n_past = ((int32_t *)(dst->op_params))[0]; + const int64_t n00x01 = ne00*ne01; + assert((n00x01*ne02)%8 == 0); [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&n00x01 length:sizeof(n00x01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n00x01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_MUL_MAT: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 531585ebf..2da0a1459 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -183,20 +183,29 @@ kernel void kernel_soft_max_4( } kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, constant int64_t & ne00, - constant int64_t & ne01, + constant int64_t & n00x01, constant int & n_past, uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; + const int64_t i = 2*tpig[0]; - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/n00x01; + i4 -= i02*n00x01; + const int64_t i01 = i4/ne00; + const int64_t i00 = i4 - i01*ne00; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } } }