diff --git a/ggml-metal.m b/ggml-metal.m index 3b5250710..d23f2c9e7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1573,6 +1573,7 @@ static enum ggml_status ggml_metal_graph_compute( //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); id pipeline = nil; + bool is_large = false; // use custom matrix x vector kernel switch (src0t) { @@ -1592,6 +1593,7 @@ static enum ggml_status ggml_metal_graph_compute( } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { if (ne01 > 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline; + is_large = true; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; } @@ -1784,7 +1786,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { const int64_t ny = (ne11 + nrows - 1)/nrows; - if (ne01 > 128) { + if (is_large) { [encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];