metal : reuse dispatch queue + autoreleasepool

commit 43a8a6297b (parent 67dd7463ce)
Georgi Gerganov, 2023-08-28 09:57:36 +03:00

@@ -33,12 +33,12 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
@@ -120,6 +120,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 #if 0
     // compile from source string and show compile log
@@ -298,6 +299,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     [ctx->queue release];
     [ctx->device release];
 
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
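
Note: the three hunks above move the GCD queue to a create-once / release-once lifecycle: the queue is created in ggml_metal_init, reused for every graph evaluation, and released in ggml_metal_free. A minimal standalone sketch of that pattern (the example_context struct and function names are illustrative, not part of ggml):

    #include <dispatch/dispatch.h>
    #include <stdlib.h>

    // Hypothetical, pared-down analogue of ggml_metal_context.
    struct example_context {
        dispatch_queue_t d_queue;
    };

    struct example_context * example_init(void) {
        struct example_context * ctx = calloc(1, sizeof(struct example_context));
        // Created once at init time; before this commit a fresh queue was
        // created (and torn down) on every call to ggml_metal_graph_compute.
        ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
        return ctx;
    }

    void example_free(struct example_context * ctx) {
        // Balanced exactly once, mirroring the dispatch_release in
        // ggml_metal_free above (valid here because the file is built
        // without ARC; under ARC the dispatch object is managed automatically).
        dispatch_release(ctx->d_queue);
        free(ctx);
    }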
@@ -563,6 +566,8 @@ void ggml_metal_graph_compute(
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    @autoreleasepool {
+
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
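
Note: the @autoreleasepool added here brackets the entire graph evaluation. ggml-metal.m is built without ARC (hence the explicit release calls elsewhere in this diff), and Metal methods such as [queue commandBuffer] return autoreleased objects, so an explicit pool guarantees those objects are reclaimed as soon as evaluation finishes rather than whenever the calling thread's outer pool happens to drain. A hedged sketch of the resulting shape (simplified, not the real function body):

    #import <Metal/Metal.h>

    static void example_graph_compute(id<MTLCommandQueue> queue) {
        @autoreleasepool {
            // Both objects are autoreleased by Metal; the pool owns them,
            // so no manual release is needed.
            id<MTLCommandBuffer>         cb  = [queue commandBuffer];
            id<MTLComputeCommandEncoder> enc = [cb computeCommandEncoder];

            // ... encode and dispatch kernels ...

            [enc endEncoding];
            [cb commit];
            [cb waitUntilCompleted];
        } // pool drains here, releasing cb and enc
    }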
@@ -589,13 +594,10 @@ void ggml_metal_graph_compute(
         command_encoders[i] = [command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
     }
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
-
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
@@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
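
Note: these two hunks form the fan-out/join pair, now running on the shared ctx->d_queue instead of a throwaway queue. dispatch_async submits one encoding block per command buffer, and dispatch_barrier_sync with an empty block serves as the join: on a concurrent queue a barrier block runs only after all previously submitted blocks have completed, and the sync variant also blocks the caller until then. A minimal sketch of the pattern in isolation (illustrative queue label and worker count):

    #include <dispatch/dispatch.h>
    #include <stdio.h>

    int main(void) {
        dispatch_queue_t q = dispatch_queue_create("example", DISPATCH_QUEUE_CONCURRENT);

        // Fan out: the blocks may run in parallel on the concurrent queue.
        for (int i = 0; i < 4; ++i) {
            dispatch_async(q, ^{
                printf("worker %d done\n", i);
            });
        }

        // Join: returns only after every block submitted above has finished.
        dispatch_barrier_sync(q, ^{});

        printf("all workers finished\n");
        dispatch_release(q);
        return 0;
    }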
@@ -1187,15 +1189,7 @@ void ggml_metal_graph_compute(
             fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
-
-        [command_encoders[i] release];
-        [command_buffers[i] release];
     }
 
-    // release resources
-    [edesc release];
-    [queue release];
-    [command_encoders release];
-    [command_buffers release];
-
+    }
 }
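
Note: with the pool in place, the manual release calls removed in this final hunk are no longer needed: the pass descriptor, command buffers, and encoders are all autoreleased objects, so they are reclaimed when the @autoreleasepool added at the top of ggml_metal_graph_compute drains. The closing brace added before the final } is that pool's end.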