metal : avoid reference of device context in the backend context
ggml-ci
This commit is contained in:
parent
34e0e6eae4
commit
70ff50d753
1 changed files with 54 additions and 31 deletions
|
@ -277,8 +277,6 @@ enum ggml_metal_kernel_type {
|
|||
};
|
||||
|
||||
struct ggml_backend_metal_context {
|
||||
struct ggml_backend_metal_device_context ctx_dev;
|
||||
|
||||
id<MTLCommandQueue> queue;
|
||||
|
||||
dispatch_queue_t d_queue;
|
||||
|
@ -343,7 +341,7 @@ static void * ggml_metal_host_malloc(size_t n) {
|
|||
return data;
|
||||
}
|
||||
|
||||
static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
||||
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
|
||||
GGML_LOG_INFO("%s: allocating\n", __func__);
|
||||
|
||||
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
||||
|
@ -357,8 +355,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|||
|
||||
// init context
|
||||
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
||||
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
||||
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(&ctx->ctx_dev);
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||
|
||||
ctx->queue = [device newCommandQueue];
|
||||
|
@ -482,9 +481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|||
}
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_reduction ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_mm ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->ctx_dev.mtl_device.hasUnifiedMemory ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
||||
|
||||
ctx->capture_next_compute = false;
|
||||
ctx->capture_started = false;
|
||||
|
@ -536,8 +535,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
|||
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
||||
}
|
||||
|
||||
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
|
||||
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
|
||||
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
||||
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
||||
|
||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||
|
||||
|
@ -740,7 +739,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|||
}
|
||||
|
||||
[ctx->queue release];
|
||||
ggml_backend_metal_device_rel(&ctx->ctx_dev);
|
||||
|
||||
dispatch_release(ctx->d_queue);
|
||||
|
||||
|
@ -798,15 +796,15 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
|||
return nil;
|
||||
}
|
||||
|
||||
static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
|
||||
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
||||
for (size_t i = 0, n = 3; i < n; ++i) {
|
||||
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
|
||||
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
|
||||
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
||||
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
|
@ -921,9 +919,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||
}
|
||||
|
||||
static void ggml_metal_encode_node(
|
||||
struct ggml_backend_metal_context * ctx,
|
||||
ggml_backend_t backend,
|
||||
int idx,
|
||||
id<MTLComputeCommandEncoder> encoder) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
|
||||
struct ggml_cgraph * gf = ctx->gf;
|
||||
|
||||
struct ggml_tensor * node = ggml_graph_node(gf, idx);
|
||||
|
@ -953,7 +954,7 @@ static void ggml_metal_encode_node(
|
|||
} break;
|
||||
}
|
||||
|
||||
if (!ggml_metal_supports_op(ctx, dst)) {
|
||||
if (!ggml_metal_supports_op(ctx_dev, dst)) {
|
||||
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
||||
GGML_ABORT("unsupported op");
|
||||
}
|
||||
|
@ -1026,7 +1027,7 @@ static void ggml_metal_encode_node(
|
|||
// dst->name);
|
||||
//}
|
||||
|
||||
id<MTLDevice> device = ctx->ctx_dev.mtl_device;
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_CONCAT:
|
||||
|
@ -3015,8 +3016,11 @@ static void ggml_metal_encode_node(
|
|||
}
|
||||
|
||||
static enum ggml_status ggml_metal_graph_compute(
|
||||
struct ggml_backend_metal_context * ctx,
|
||||
struct ggml_cgraph * gf) {
|
||||
ggml_backend_t backend,
|
||||
struct ggml_cgraph * gf) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
|
||||
// number of nodes encoded by the main thread (empirically determined)
|
||||
const int n_main = 128;
|
||||
|
||||
|
@ -3044,7 +3048,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
if (!ctx->capture_started) {
|
||||
// create capture scope
|
||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->ctx_dev.mtl_device];
|
||||
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
|
||||
|
||||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||
descriptor.captureObject = ctx->capture_scope;
|
||||
|
@ -3087,7 +3091,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
|
||||
}
|
||||
|
||||
ggml_metal_encode_node(ctx, idx, encoder);
|
||||
ggml_metal_encode_node(backend, idx, encoder);
|
||||
|
||||
if (should_capture) {
|
||||
[encoder popDebugGroup];
|
||||
|
@ -3462,6 +3466,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|||
|
||||
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
||||
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
ggml_metal_free(ctx);
|
||||
free(backend);
|
||||
}
|
||||
|
@ -3473,9 +3479,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
|
|||
}
|
||||
|
||||
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
|
||||
|
||||
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
||||
return ggml_metal_graph_compute(backend, cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
|
@ -3522,8 +3526,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
|
|||
return &guid;
|
||||
}
|
||||
|
||||
// TODO: remove in the future
|
||||
ggml_backend_t ggml_backend_metal_init(void) {
|
||||
struct ggml_backend_metal_context * ctx = ggml_metal_init();
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
|
||||
|
||||
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
|
@ -3534,7 +3541,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|||
*backend = (struct ggml_backend) {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ &g_ggml_backend_metal_device,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
|
@ -3559,9 +3566,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
|
|||
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||
|
||||
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
|
||||
return [ctx->ctx_dev.mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
}
|
||||
|
||||
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
||||
|
@ -3623,9 +3630,25 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
|
|||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
return ggml_backend_metal_init();
|
||||
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
|
||||
if (ctx == NULL) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
|
||||
|
||||
*backend = (struct ggml_backend) {
|
||||
/* .guid = */ ggml_backend_metal_guid(),
|
||||
/* .interface = */ ggml_backend_metal_i,
|
||||
/* .device = */ dev,
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
ggml_backend_metal_set_n_cb(backend, 1);
|
||||
|
||||
return backend;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
|
@ -3715,9 +3738,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)dev->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
||||
|
||||
return ggml_metal_supports_op(metal_ctx, op);
|
||||
return ggml_metal_supports_op(ctx_dev, op);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue