diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index ecfdf5465..8c20fc98f 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -277,8 +277,6 @@ enum ggml_metal_kernel_type {
 };
 
 struct ggml_backend_metal_context {
-    struct ggml_backend_metal_device_context ctx_dev;
-
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
@@ -343,7 +341,7 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }
 
-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);
 
 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -357,8 +355,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
 
     // init context
     struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
-    id<MTLDevice> device = ggml_backend_metal_device_acq(&ctx->ctx_dev);
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
     ctx->queue = [device newCommandQueue];
@@ -482,9 +481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         }
     }
 
-    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->ctx_dev.mtl_device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
     ctx->capture_started = false;
@@ -536,8 +535,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
-    const bool support_simdgroup_mm        = ctx->ctx_dev.support_simdgroup_mm;
-    const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
 
     // simd_sum and simd_max requires MTLGPUFamilyApple7
@@ -740,7 +739,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     }
 
     [ctx->queue release];
-    ggml_backend_metal_device_rel(&ctx->ctx_dev);
 
     dispatch_release(ctx->d_queue);
@@ -798,15 +796,15 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
     return nil;
 }
 
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }
 
-    const bool support_simdgroup_mm        = ctx->ctx_dev.support_simdgroup_mm;
-    const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
 
     switch (op->op) {
         case GGML_OP_UNARY:
@@ -921,9 +919,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
 }
 
 static void ggml_metal_encode_node(
-        struct ggml_backend_metal_context * ctx,
+        ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     struct ggml_cgraph * gf = ctx->gf;
 
     struct ggml_tensor * node = ggml_graph_node(gf, idx);
@@ -953,7 +954,7 @@ static void ggml_metal_encode_node(
             } break;
     }
 
-    if (!ggml_metal_supports_op(ctx, dst)) {
+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
        GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
        GGML_ABORT("unsupported op");
     }
@@ -1026,7 +1027,7 @@ static void ggml_metal_encode_node(
    //       dst->name);
    //}
 
-    id<MTLDevice> device = ctx->ctx_dev.mtl_device;
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
    switch (dst->op) {
        case GGML_OP_CONCAT:
@@ -3015,8 +3016,11 @@ static void ggml_metal_encode_node(
 }
 
 static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-        struct ggml_cgraph * gf) {
+        ggml_backend_t backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
@@ -3044,7 +3048,7 @@ static enum ggml_status ggml_metal_graph_compute(
         if (!ctx->capture_started) {
             // create capture scope
-            ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->ctx_dev.mtl_device];
+            ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
 
             MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
             descriptor.captureObject = ctx->capture_scope;
@@ -3087,7 +3091,7 @@ static enum ggml_status ggml_metal_graph_compute(
                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
            }
 
-            ggml_metal_encode_node(ctx, idx, encoder);
+            ggml_metal_encode_node(backend, idx, encoder);
 
            if (should_capture) {
                [encoder popDebugGroup];
@@ -3462,6 +3466,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
     struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
     free(backend);
 }
@@ -3473,9 +3479,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
 }
 
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_graph_compute(metal_ctx, cgraph);
+    return ggml_metal_graph_compute(backend, cgraph);
 }
 
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -3522,8 +3526,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
     return &guid;
 }
 
+// TODO: remove in the future
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init();
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;
@@ -3534,7 +3541,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
 
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ &g_ggml_backend_metal_device,
+        /* .device    = */ dev,
         /* .context   = */ ctx,
     };
@@ -3559,9 +3566,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
-    return [ctx->ctx_dev.mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
 void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@@ -3623,9 +3630,25 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
 }
 
 static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
-    return ggml_backend_metal_init();
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
 
-    GGML_UNUSED(dev);
     GGML_UNUSED(params);
 }
@@ -3715,9 +3738,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
 }
 
 static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)dev->context;
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
-    return ggml_metal_supports_op(metal_ctx, op);
+    return ggml_metal_supports_op(ctx_dev, op);
 }
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
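
For context, a minimal caller-side sketch of how the Metal backend is obtained after this change, going through the backend registry/device path that the patch wires up (this is an illustrative example, not part of the patch; it assumes the generic device API from ggml-backend.h and the ggml_backend_metal_reg() declaration from ggml-metal.h):

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    #include <stdio.h>

    int main(void) {
        // look up the single Metal device exposed by the backend registry
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);

        // initialize the backend through the device; this ends up in
        // ggml_backend_metal_device_init(), which now builds the backend itself
        ggml_backend_t backend = ggml_backend_dev_init(dev, /* params = */ NULL);
        if (backend == NULL) {
            fprintf(stderr, "failed to initialize Metal backend\n");
            return 1;
        }

        printf("backend: %s\n", ggml_backend_name(backend));

        ggml_backend_free(backend);
        return 0;
    }

The legacy ggml_backend_metal_init() entry point kept by this patch performs the same registry lookup internally, which is why it is now marked "TODO: remove in the future".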