diff --git a/ggml-metal.m b/ggml-metal.m index e825b630b..1ab8ae8e2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -33,12 +33,12 @@ struct ggml_metal_buffer { struct ggml_metal_context { int n_cb; - float * logits; - id device; id queue; id library; + dispatch_queue_t d_queue; + int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; @@ -120,6 +120,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->n_buffers = 0; ctx->concur_list_len = 0; + ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); #if 0 // compile from source string and show compile log @@ -298,6 +299,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { [ctx->queue release]; [ctx->device release]; + dispatch_release(ctx->d_queue); + free(ctx); } @@ -563,6 +566,8 @@ void ggml_metal_graph_compute( struct ggml_cgraph * gf) { metal_printf("%s: evaluating graph\n", __func__); + @autoreleasepool { + // if there is ctx->concur_list, dispatch concurrently // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; @@ -589,13 +594,10 @@ void ggml_metal_graph_compute( command_encoders[i] = [command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; } - // TODO: is this the best way to start threads? - dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - dispatch_async(queue, ^{ + dispatch_async(ctx->d_queue, ^{ size_t offs_src0 = 0; size_t offs_src1 = 0; size_t offs_dst = 0; @@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute( } // wait for all threads to finish - dispatch_barrier_sync(queue, ^{}); + dispatch_barrier_sync(ctx->d_queue, ^{}); // check status of command buffers // needed to detect if the device ran out-of-memory for example (#1881) @@ -1187,15 +1189,7 @@ void ggml_metal_graph_compute( fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status); GGML_ASSERT(false); } - - [command_encoders[i] release]; - [command_buffers[i] release]; } - // release resources - [edesc release]; - [queue release]; - - [command_encoders release]; - [command_buffers release]; + } }