metal : fix Metal API debug warnings
This commit is contained in:
parent
75c14f2608
commit
515cfec44f
2 changed files with 10 additions and 8 deletions
|
@ -155,9 +155,9 @@ if (APPLE AND LLAMA_ACCELERATE)
|
|||
endif()
|
||||
|
||||
if (LLAMA_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
|
||||
message(STATUS "Metal framework found")
|
||||
set(GGML_HEADERS_METAL ggml-metal.h)
|
||||
|
|
12
ggml-metal.m
12
ggml-metal.m
|
@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
|
|||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
@ -1793,8 +1793,9 @@ void ggml_metal_graph_compute(
|
|||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < n_as; ++j) {
|
||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||
|
@ -1917,8 +1918,9 @@ void ggml_metal_graph_compute(
|
|||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < n_as; ++j) {
|
||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue