metal : reuse array for command buffers and encoders

This commit is contained in:
Georgi Gerganov 2023-08-28 10:49:27 +03:00
parent 43a8a6297b
commit fffd167069
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 13 additions and 12 deletions

View file

@ -24,6 +24,7 @@
// max memory buffers that can be mapped to the device // max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16 #define GGML_METAL_MAX_BUFFERS 16
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor; struct ggml_tensor;
struct ggml_cgraph; struct ggml_cgraph;

View file

@ -37,6 +37,9 @@ struct ggml_metal_context {
id<MTLCommandQueue> queue; id<MTLCommandQueue> queue;
id<MTLLibrary> library; id<MTLLibrary> library;
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
dispatch_queue_t d_queue; dispatch_queue_t d_queue;
int n_buffers; int n_buffers;
@ -114,7 +117,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
ctx->n_cb = n_cb; ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->device = MTLCreateSystemDefaultDevice(); ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue]; ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0; ctx->n_buffers = 0;
@ -320,7 +323,7 @@ void ggml_metal_host_free(void * data) {
} }
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb; ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
} }
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@ -582,16 +585,13 @@ void ggml_metal_graph_compute(
const int n_cb = ctx->n_cb; const int n_cb = ctx->n_cb;
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
NSMutableArray * command_encoders = [NSMutableArray arrayWithCapacity:n_cb];
for (int i = 0; i < n_cb; ++i) { for (int i = 0; i < n_cb; ++i) {
command_buffers[i] = [ctx->queue commandBuffer]; ctx->command_buffers[i] = [ctx->queue commandBuffer];
// enqueue the command buffers in order to specify their execution order // enqueue the command buffers in order to specify their execution order
[command_buffers[i] enqueue]; [ctx->command_buffers[i] enqueue];
command_encoders[i] = [command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
} }
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
@ -602,8 +602,8 @@ void ggml_metal_graph_compute(
size_t offs_src1 = 0; size_t offs_src1 = 0;
size_t offs_dst = 0; size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx]; id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = command_encoders[cb_idx]; id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@ -1182,9 +1182,9 @@ void ggml_metal_graph_compute(
// check status of command buffers // check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881) // needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) { for (int i = 0; i < n_cb; i++) {
[command_buffers[i] waitUntilCompleted]; [ctx->command_buffers[i] waitUntilCompleted];
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status]; MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) { if (status != MTLCommandBufferStatusCompleted) {
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status); fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false); GGML_ASSERT(false);