diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e2e987de..71e23f041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -155,9 +155,9 @@ if (APPLE AND LLAMA_ACCELERATE) endif() if (LLAMA_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) message(STATUS "Metal framework found") set(GGML_HEADERS_METAL ggml-metal.h) diff --git a/ggml-metal.m b/ggml-metal.m index abe28bbb7..a673f0603 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - const int nth = MIN(1024, ne0); + const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00); [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1793,8 +1793,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; // TODO: how to make this an array? read Metal docs - for (int j = 0; j < n_as; ++j) { - struct ggml_tensor * src_cur = dst->src[2 + j]; + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); @@ -1917,8 +1918,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; // TODO: how to make this an array? read Metal docs - for (int j = 0; j < n_as; ++j) { - struct ggml_tensor * src_cur = dst->src[2 + j]; + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);