From e180fcd3d53cee6ed5ac82473b45c468379a0687 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Apr 2024 11:04:32 +0300 Subject: [PATCH] metal : fix max nsg ggml-ci --- ggml-metal.m | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 249b8312c..c6d580b84 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2643,13 +2643,25 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];