Single allocation of encode_async block with non-ARC capture in ggml-metal.m
This commit is contained in:
parent
71967c2a6d
commit
7403c05c06
1 changed file with 41 additions and 44 deletions
|
@ -438,6 +438,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
||||||
ctx->capture_scope = nil;
|
ctx->capture_scope = nil;
|
||||||
|
|
||||||
ctx->gf = nil;
|
ctx->gf = nil;
|
||||||
|
Block_release(ctx->encode_async);
|
||||||
ctx->encode_async = nil;
|
ctx->encode_async = nil;
|
||||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||||
ctx->command_buffers[i] = nil;
|
ctx->command_buffers[i] = nil;
|
||||||
|
@ -3000,46 +3001,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
|
|
||||||
ctx->encode_async = ^(size_t iter) {
|
|
||||||
const int cb_idx = iter;
|
|
||||||
const int n_cb_l = ctx->n_cb;
|
|
||||||
|
|
||||||
const int n_nodes_0 = ctx->n_nodes_0;
|
|
||||||
const int n_nodes_1 = ctx->n_nodes_1;
|
|
||||||
|
|
||||||
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
|
||||||
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
|
||||||
|
|
||||||
int node_start = 0;
|
|
||||||
int node_end = n_nodes_0;
|
|
||||||
|
|
||||||
if (cb_idx < n_cb_l) {
|
|
||||||
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
|
||||||
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int idx = node_start; idx < node_end; ++idx) {
|
|
||||||
if (should_capture) {
|
|
||||||
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_metal_encode_node(ctx, idx, encoder);
|
|
||||||
|
|
||||||
if (should_capture) {
|
|
||||||
[encoder popDebugGroup];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
[encoder endEncoding];
|
|
||||||
|
|
||||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
|
||||||
[command_buffer commit];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// the main thread commits the first few commands immediately
|
// the main thread commits the first few commands immediately
|
||||||
// command_buffer[n_cb]
|
// command_buffer[n_cb]
|
||||||
{
|
{
|
||||||
|
@ -3468,10 +3429,46 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
|
ctx->encode_async = Block_copy(^(size_t iter) {
|
||||||
//ctx->encode_async = ^(size_t iter) {
|
const int cb_idx = iter;
|
||||||
// ...
|
const int n_cb_l = ctx->n_cb;
|
||||||
//};
|
|
||||||
|
const int n_nodes_0 = ctx->n_nodes_0;
|
||||||
|
const int n_nodes_1 = ctx->n_nodes_1;
|
||||||
|
|
||||||
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||||
|
|
||||||
|
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
||||||
|
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
||||||
|
|
||||||
|
int node_start = 0;
|
||||||
|
int node_end = n_nodes_0;
|
||||||
|
|
||||||
|
if (cb_idx < n_cb_l) {
|
||||||
|
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
||||||
|
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool should_capture = ctx->capture_next_compute;
|
||||||
|
|
||||||
|
for (int idx = node_start; idx < node_end; ++idx) {
|
||||||
|
if (should_capture) {
|
||||||
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_encode_node(ctx, idx, encoder);
|
||||||
|
|
||||||
|
if (should_capture) {
|
||||||
|
[encoder popDebugGroup];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder endEncoding];
|
||||||
|
|
||||||
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
|
[command_buffer commit];
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_backend_i ggml_backend_metal_i = {
|
static struct ggml_backend_i ggml_backend_metal_i = {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue