Collecting command buffer completions on single thread

This commit is contained in:
Paul Tsochantaris 2024-01-14 18:36:55 +00:00
parent b95842ae4e
commit 3f787a4d5a

View file

@ -720,10 +720,11 @@ static bool ggml_metal_graph_compute(
const int n_nodes = gf->n_nodes; const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb; const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
__block BOOL all_buffers_succeeded = true; id<MTLCommandBuffer> command_buffers[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffers[cb_idx] = command_buffer;
// enqueue the command buffers in order to specify their execution order // enqueue the command buffers in order to specify their execution order
[command_buffer enqueue]; [command_buffer enqueue];
@ -2229,21 +2230,25 @@ static bool ggml_metal_graph_compute(
[encoder endEncoding]; [encoder endEncoding];
[command_buffer commit]; [command_buffer commit];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
all_buffers_succeeded = false;
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, cb_idx, status);
}
}); });
} }
// wait for all threads to finish // Wait for all command buffers to be committed
dispatch_barrier_sync(ctx->d_queue, ^{}); dispatch_barrier_sync(ctx->d_queue, ^{});
return all_buffers_succeeded; // Wait for completion and check status of each command buffer
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, cb_idx, status);
return false;
}
}
return true;
} }
} }