metal : avoid reference of device context in the backend context

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-07 11:46:34 +03:00
parent 34e0e6eae4
commit 70ff50d753
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -277,8 +277,6 @@ enum ggml_metal_kernel_type {
};
struct ggml_backend_metal_context {
struct ggml_backend_metal_device_context ctx_dev;
id<MTLCommandQueue> queue;
dispatch_queue_t d_queue;
@ -343,7 +341,7 @@ static void * ggml_metal_host_malloc(size_t n) {
return data;
}
static struct ggml_backend_metal_context * ggml_metal_init(void) {
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("%s: allocating\n", __func__);
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@ -357,8 +355,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
// init context
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(&ctx->ctx_dev);
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
ctx->queue = [device newCommandQueue];
@ -482,9 +481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
}
}
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_reduction ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_mm ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->ctx_dev.mtl_device.hasUnifiedMemory ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
ctx->capture_next_compute = false;
ctx->capture_started = false;
@ -536,8 +535,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
}
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
// simd_sum and simd_max requires MTLGPUFamilyApple7
@ -740,7 +739,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
}
[ctx->queue release];
ggml_backend_metal_device_rel(&ctx->ctx_dev);
dispatch_release(ctx->d_queue);
@ -798,15 +796,15 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
return nil;
}
static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
return false;
}
}
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
switch (op->op) {
case GGML_OP_UNARY:
@ -921,9 +919,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
}
static void ggml_metal_encode_node(
struct ggml_backend_metal_context * ctx,
ggml_backend_t backend,
int idx,
id<MTLComputeCommandEncoder> encoder) {
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
struct ggml_cgraph * gf = ctx->gf;
struct ggml_tensor * node = ggml_graph_node(gf, idx);
@ -953,7 +954,7 @@ static void ggml_metal_encode_node(
} break;
}
if (!ggml_metal_supports_op(ctx, dst)) {
if (!ggml_metal_supports_op(ctx_dev, dst)) {
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ABORT("unsupported op");
}
@ -1026,7 +1027,7 @@ static void ggml_metal_encode_node(
// dst->name);
//}
id<MTLDevice> device = ctx->ctx_dev.mtl_device;
id<MTLDevice> device = ctx_dev->mtl_device;
switch (dst->op) {
case GGML_OP_CONCAT:
@ -3015,8 +3016,11 @@ static void ggml_metal_encode_node(
}
static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
ggml_backend_t backend,
struct ggml_cgraph * gf) {
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
@ -3044,7 +3048,7 @@ static enum ggml_status ggml_metal_graph_compute(
if (!ctx->capture_started) {
// create capture scope
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->ctx_dev.mtl_device];
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->capture_scope;
@ -3087,7 +3091,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(ctx, idx, encoder);
ggml_metal_encode_node(backend, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
@ -3462,6 +3466,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
static void ggml_backend_metal_free(ggml_backend_t backend) {
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
ggml_backend_metal_device_rel(ctx_dev);
ggml_metal_free(ctx);
free(backend);
}
@ -3473,9 +3479,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
}
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
return ggml_metal_graph_compute(metal_ctx, cgraph);
return ggml_metal_graph_compute(backend, cgraph);
}
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@ -3522,8 +3526,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
return &guid;
}
// TODO: remove in the future
ggml_backend_t ggml_backend_metal_init(void) {
struct ggml_backend_metal_context * ctx = ggml_metal_init();
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
if (ctx == NULL) {
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
@ -3534,7 +3541,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
*backend = (struct ggml_backend) {
/* .guid = */ ggml_backend_metal_guid(),
/* .interface = */ ggml_backend_metal_i,
/* .device = */ &g_ggml_backend_metal_device,
/* .device = */ dev,
/* .context = */ ctx,
};
@ -3559,9 +3566,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
return [ctx->ctx_dev.mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@ -3623,9 +3630,25 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
}
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
return ggml_backend_metal_init();
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
if (ctx == NULL) {
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
}
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
*backend = (struct ggml_backend) {
/* .guid = */ ggml_backend_metal_guid(),
/* .interface = */ ggml_backend_metal_i,
/* .device = */ dev,
/* .context = */ ctx,
};
ggml_backend_metal_set_n_cb(backend, 1);
return backend;
GGML_UNUSED(dev);
GGML_UNUSED(params);
}
@ -3715,9 +3738,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
}
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)dev->context;
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
return ggml_metal_supports_op(metal_ctx, op);
return ggml_metal_supports_op(ctx_dev, op);
}
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {