Metal: Localized logic in ggml_metal_graph_compute, minor performance improvement

This commit is contained in:
Paul Tsochantaris 2024-01-13 23:15:28 +00:00
parent 76484fbfd3
commit b5f795f326
2 changed files with 19 additions and 38 deletions

View file

@ -27,7 +27,6 @@
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor;
struct ggml_cgraph;

View file

@ -170,9 +170,6 @@ struct ggml_metal_context {
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
dispatch_queue_t d_queue;
int n_buffers;
@ -715,41 +712,33 @@ static bool ggml_metal_graph_compute(
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const int n_nodes = gf->n_nodes;
edesc.dispatchType = MTLDispatchTypeSerial;
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb;
for (int i = 0; i < n_cb; ++i) {
ctx->command_buffers[i] = [ctx->queue commandBuffer];
// enqueue the command buffers in order to specify their execution order
[ctx->command_buffers[i] enqueue];
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
}
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
__block BOOL all_buffers_succeeded = true;
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
dispatch_async(ctx->d_queue, ^{
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
for (int ind = node_start; ind < node_end; ++ind) {
const int i = ind;
for (int i = node_start; i < node_end; ++i) {
if (i == -1) {
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
@ -2237,31 +2226,24 @@ static bool ggml_metal_graph_compute(
#endif
}
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
}
[encoder endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
all_buffers_succeeded = false;
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, cb_idx, status);
}
});
}
// wait for all blocks dispatched to ctx->d_queue above to finish
dispatch_barrier_sync(ctx->d_queue, ^{});
// check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) {
[ctx->command_buffers[i] waitUntilCompleted];
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
return false;
}
}
return true;
return all_buffers_succeeded;
}
}