diff --git a/ggml-metal.h b/ggml-metal.h index 00202b787..fca28d37e 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -24,6 +24,7 @@ // max memory buffers that can be mapped to the device #define GGML_METAL_MAX_BUFFERS 16 +#define GGML_METAL_MAX_COMMAND_BUFFERS 32 struct ggml_tensor; struct ggml_cgraph; diff --git a/ggml-metal.m b/ggml-metal.m index 1ab8ae8e2..ad2ee8cf5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -37,6 +37,9 @@ struct ggml_metal_context { id queue; id library; + id command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS]; + id command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS]; + dispatch_queue_t d_queue; int n_buffers; @@ -114,7 +117,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); - ctx->n_cb = n_cb; + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); ctx->device = MTLCreateSystemDefaultDevice(); ctx->queue = [ctx->device newCommandQueue]; ctx->n_buffers = 0; @@ -320,7 +323,7 @@ void ggml_metal_host_free(void * data) { } void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { - ctx->n_cb = n_cb; + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); } int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { @@ -582,16 +585,13 @@ void ggml_metal_graph_compute( const int n_cb = ctx->n_cb; - NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb]; - NSMutableArray * command_encoders = [NSMutableArray arrayWithCapacity:n_cb]; - for (int i = 0; i < n_cb; ++i) { - command_buffers[i] = [ctx->queue commandBuffer]; + ctx->command_buffers[i] = [ctx->queue commandBuffer]; // enqueue the command buffers in order to specify their execution order - [command_buffers[i] enqueue]; + [ctx->command_buffers[i] enqueue]; - command_encoders[i] = [command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; + ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; } for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { @@ -602,8 +602,8 @@ void ggml_metal_graph_compute( size_t offs_src1 = 0; size_t offs_dst = 0; - id command_buffer = command_buffers[cb_idx]; - id encoder = command_encoders[cb_idx]; + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = ctx->command_encoders[cb_idx]; const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); @@ -1182,9 +1182,9 @@ void ggml_metal_graph_compute( // check status of command buffers // needed to detect if the device ran out-of-memory for example (#1881) for (int i = 0; i < n_cb; i++) { - [command_buffers[i] waitUntilCompleted]; + [ctx->command_buffers[i] waitUntilCompleted]; - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status]; + MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; if (status != MTLCommandBufferStatusCompleted) { fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status); GGML_ASSERT(false);