metal : reduce command encoding overhead

ggml-ci
Georgi Gerganov 2024-10-01 10:23:44 +03:00
parent 1927378bcc
commit 43b9d694df
5 changed files with 1992 additions and 1907 deletions
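For orientation, the gist of the change (taken from the Metal backend hunks below): instead of slicing the graph evenly across n_cb command buffers on every compute call, the backend now keeps the command queue, compute-pass descriptor and command buffers in its context, encodes the first up-to-128 nodes into an extra command buffer that is committed immediately, and splits the remaining nodes across n_cb command buffers that are encoded concurrently via dispatch_apply. The number of command buffers is fixed at backend init (default 1) instead of being set per call. A minimal sketch of the node-range arithmetic in plain C; split_ranges() and the values in main() are illustrative only:

    // mirrors the n_nodes_0 / n_nodes_1 / n_nodes_per_cb fields introduced in the diff below
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    static void split_ranges(int n_nodes, int n_cb) {
        const int n_nodes_0      = MIN(128, n_nodes);             // encoded and committed right away
        const int n_nodes_1      = n_nodes - n_nodes_0;           // spread across the n_cb parallel buffers
        const int n_nodes_per_cb = (n_nodes_1 + n_cb - 1) / n_cb; // ceiling division

        printf("cb %d (extra): nodes [0, %d)\n", n_cb, n_nodes_0);

        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = n_nodes_0 + (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = n_nodes_0 + MIN((cb_idx == n_cb - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1);
            printf("cb %d: nodes [%d, %d)\n", cb_idx, node_start, node_end);
        }
    }

    int main(void) {
        split_ranges(/*n_nodes =*/ 300, /*n_cb =*/ 2); // e.g. a 300-node graph with two command buffers
        return 0;
    }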

View file

@@ -204,13 +204,6 @@ static ggml_status compute_piter(
ggml_backend_cpu_set_n_threads(model.backend, params.n_threads);
}
// TODO: enable GPU support when support for GGML_OP_SQRT is added
//#ifdef GGML_USE_METAL
// if (ggml_backend_is_metal(model.backend)) {
// ggml_backend_metal_set_n_cb(model.backend, params.n_threads);
// }
//#endif
ggml_status res = ggml_backend_graph_compute(model.backend, gf);
if (res == GGML_STATUS_SUCCESS) {
auto extract_i = [](std::string prefix, std::string str) -> int {

View file

@@ -2444,12 +2444,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(ctx->backend)) {
ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
}
#endif
ggml_backend_graph_compute(ctx->backend, gf);
// the last node is the embedding tensor

View file

@@ -25,9 +25,6 @@
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
struct ggml_tensor;
struct ggml_cgraph;
@@ -48,8 +45,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
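Note that ggml_backend_metal_set_n_cb is dropped from this public header (it becomes a static function in the backend implementation, see the hunk further below), and GGML_METAL_MAX_BUFFERS moves into the implementation file. The compute_piter, clip_image_batch_encode and llama_graph_compute hunks in this commit therefore simply delete their calls to it. A sketch of the resulting caller-side pattern, assuming the standard ggml-backend API; the helper name is illustrative and tensor/graph buffer allocation is omitted:

    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    static enum ggml_status compute_graph(struct ggml_cgraph * gf, int n_threads) {
        // no Metal-specific setup any more: the backend picks its command-buffer count at init
        ggml_backend_t backend = ggml_backend_metal_init();
        if (backend == NULL) {
            backend = ggml_backend_cpu_init();                  // CPU fallback
            ggml_backend_cpu_set_n_threads(backend, n_threads); // n_threads now only affects the CPU backend
        }

        const enum ggml_status res = ggml_backend_graph_compute(backend, gf);

        ggml_backend_free(backend);
        return res;
    }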

View file

@@ -12,6 +12,11 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
#define GGML_METAL_MAX_COMMAND_BUFFERS 128
#ifdef GGML_METAL_NDEBUG
#define GGML_METAL_LOG(...)
#define GGML_METAL_LOG_INFO(...)
@@ -226,6 +231,8 @@ struct ggml_backend_metal_context {
id<MTLDevice> device;
id<MTLCommandQueue> queue;
MTLComputePassDescriptor * edesc;
dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -234,6 +241,19 @@ struct ggml_backend_metal_context {
bool support_simdgroup_mm;
bool should_capture_next_compute;
bool capture_started;
id<MTLCaptureScope> cap_scope;
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
int n_nodes_0;
int n_nodes_1;
int n_nodes_per_cb;
struct ggml_cgraph * gf;
void (^encode_async)(size_t ith);
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
@@ -303,7 +323,7 @@ static void * ggml_metal_host_malloc(size_t n) {
return data;
}
static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
static struct ggml_backend_metal_context * ggml_metal_init(void) {
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -322,8 +342,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
// Configure context
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
ctx->device = device;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->queue = [ctx->device newCommandQueue];
ctx->edesc = MTLComputePassDescriptor.computePassDescriptor;
ctx->edesc.dispatchType = MTLDispatchTypeSerial;
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library;
@@ -456,6 +477,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
ctx->should_capture_next_compute = false;
ctx->capture_started = false;
ctx->cap_scope = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->command_buffers[i] = nil;
}
ctx->encode_async = nil;
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
@@ -686,6 +716,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
}
[metal_library release];
return ctx;
}
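The init hunks above also create ctx->d_queue as a concurrent libdispatch queue; the graph-compute path further down hands ctx->encode_async to dispatch_apply, which runs the block once per command buffer and only returns when every iteration has finished. A minimal standalone sketch of that libdispatch pattern (plain C with blocks; compile with clang on macOS):

    #include <dispatch/dispatch.h>
    #include <stdio.h>

    int main(void) {
        dispatch_queue_t d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

        const size_t n_cb = 4; // one iteration per command buffer
        dispatch_apply(n_cb, d_queue, ^(size_t iter) {
            // in the backend this block is ctx->encode_async: encode the node range of buffer `iter`
            printf("encoding command buffer %zu\n", iter);
        });
        // all iterations have completed here, before the status checks run

        dispatch_release(d_queue);
        return 0;
    }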
@@ -874,78 +905,23 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
}
}
static enum ggml_status ggml_metal_graph_compute(
static void ggml_metal_encode_node(
struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
int idx,
id<MTLComputeCommandEncoder> encoder) {
struct ggml_cgraph * gf = ctx->gf;
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
edesc.dispatchType = MTLDispatchTypeSerial;
struct ggml_tensor * node = ggml_graph_node(gf, idx);
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
const bool should_capture = ctx->should_capture_next_compute;
if (should_capture) {
ctx->should_capture_next_compute = false;
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->queue;
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
GGML_ABORT("capture failed");
}
}
id<MTLCommandBuffer> command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
}
}
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
const int cb_idx = iter;
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
for (int i = node_start; i < node_end; ++i) {
if (i == -1) {
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * dst = gf->nodes[i];
struct ggml_tensor * src0 = node->src[0];
struct ggml_tensor * src1 = node->src[1];
struct ggml_tensor * src2 = node->src[2];
struct ggml_tensor * dst = node;
if (ggml_is_empty(dst)) {
continue;
return;
}
switch (dst->op) {
@@ -956,7 +932,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_PERMUTE:
{
// noop -> next node
} continue;
} return;
default:
{
} break;
@@ -967,10 +943,6 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ABORT("unsupported op");
}
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
}
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1015,6 +987,11 @@ static enum ggml_status ggml_metal_graph_compute(
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
@@ -1039,7 +1016,7 @@ static enum ggml_status ggml_metal_graph_compute(
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
const int32_t dim = ((int32_t *) dst->op_params)[0];
const int32_t dim = ((const int32_t *) dst->op_params)[0];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1203,12 +1180,12 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
const size_t offs = ((int32_t *) dst->op_params)[3];
const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
const size_t offs = ((const int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
if (!inplace) {
// run a separate kernel to cpy src->dst
@@ -1309,8 +1286,8 @@ static enum ggml_status ggml_metal_graph_compute(
float min;
float max;
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1323,7 +1300,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
switch (ggml_get_unary_op(node)) {
// we are not taking into account the strides, so for now require contiguous tensors
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1422,7 +1399,7 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
GGML_ABORT("fatal error");
}
} break;
@@ -1551,8 +1528,8 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
@@ -1585,7 +1562,7 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_DIAG_MASK_INF:
{
const int n_past = ((int32_t *)(dst->op_params))[0];
const int n_past = ((const int32_t *)(dst->op_params))[0];
id<MTLComputePipelineState> pipeline = nil;
@@ -1644,9 +1621,9 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_SSM_SCAN:
{
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
struct ggml_tensor * src4 = gf->nodes[i]->src[4];
struct ggml_tensor * src5 = gf->nodes[i]->src[5];
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
@@ -2425,7 +2402,7 @@ static enum ggml_status ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
int nth = 32; // SIMD width
@@ -2479,11 +2456,11 @@ static enum ggml_status ggml_metal_graph_compute(
const int nth = MIN(1024, ne00);
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_past = ((const int32_t *) dst->op_params)[0];
const int n_dims = ((const int32_t *) dst->op_params)[1];
const int mode = ((const int32_t *) dst->op_params)[2];
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
float freq_base;
float freq_scale;
@@ -2492,12 +2469,12 @@ static enum ggml_status ggml_metal_graph_compute(
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -2686,8 +2663,8 @@ static enum ggml_status ggml_metal_graph_compute(
float start;
float step;
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
@@ -2786,7 +2763,7 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ggml_are_same_shape (src1, src2));
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
struct ggml_tensor * src3 = node->src[3];
size_t offs_src3 = 0;
@@ -2811,9 +2788,9 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
float logit_softcap;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
@@ -3014,10 +2991,81 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
GGML_ABORT("fatal error");
}
}
}
static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
@autoreleasepool {
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
const int n_cb = ctx->n_cb;
ctx->gf = gf;
ctx->n_nodes_0 = MIN(128, gf->n_nodes);
ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + n_cb - 1) / n_cb;
//const int64_t t_start = ggml_time_us();
const bool should_capture = ctx->should_capture_next_compute;
if (should_capture) {
ctx->should_capture_next_compute = false;
if (!ctx->capture_started) {
// create capture scope
ctx->cap_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->cap_scope;
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
GGML_ABORT("capture failed");
} else {
[ctx->cap_scope beginScope];
ctx->capture_started = true;
}
}
}
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
ctx->encode_async = ^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
const int n_nodes_0 = ctx->n_nodes_0;
const int n_nodes_1 = ctx->n_nodes_1;
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: ctx->edesc];
int node_start = 0;
int node_end = n_nodes_0;
if ((int) iter < n_cb_l) {
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
}
for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(ctx, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
@@ -3029,13 +3077,55 @@ static enum ggml_status ggml_metal_graph_compute(
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
}
});
};
{
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[n_cb] = command_buffer;
[command_buffer enqueue];
ctx->encode_async(n_cb);
}
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[cb_idx] = command_buffer;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
}
}
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
//{
// const int64_t t_end = ggml_time_us();
// //printf("time to encode: %d us, n_cb = %d\n", (int) (t_end - t_start), n_cb);
//}
// Wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
}
}
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> command_buffer = command_buffers[i];
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
@@ -3048,7 +3138,7 @@ static enum ggml_status ggml_metal_graph_compute(
return GGML_STATUS_FAILED;
}
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
if (!next_buffer) {
continue;
}
@@ -3066,11 +3156,17 @@ static enum ggml_status ggml_metal_graph_compute(
[next_buffer commit];
}
if (should_capture) {
//{
// const int64_t t_end = ggml_time_us();
// printf("time to compute: %d us\n", (int)(t_end - t_start));
//}
if (!should_capture && ctx->capture_started) {
[ctx->cap_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];
}
}
return GGML_STATUS_SUCCESS;
}
@@ -3405,6 +3501,25 @@ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, g
UNUSED(backend);
}
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
if (ctx->n_cb != n_cb) {
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
if (ctx->n_cb > 2) {
GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
}
}
// TODO: setting encode_async here causes crash. why?
//ctx->encode_async = ^(size_t iter) {
// ...
//};
}
static struct ggml_backend_i ggml_backend_metal_i = {
/* .get_name = */ ggml_backend_metal_name,
/* .free = */ ggml_backend_metal_free,
@@ -3439,35 +3554,29 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
}
ggml_backend_t ggml_backend_metal_init(void) {
struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
struct ggml_backend_metal_context * ctx = ggml_metal_init();
if (ctx == NULL) {
GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
}
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
*metal_backend = (struct ggml_backend) {
*backend = (struct ggml_backend) {
/* .guid = */ ggml_backend_metal_guid(),
/* .interface = */ ggml_backend_metal_i,
/* .context = */ ctx,
};
return metal_backend;
ggml_backend_metal_set_n_cb(backend, 1);
return backend;
}
bool ggml_backend_is_metal(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
}
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
GGML_ASSERT(ggml_backend_is_metal(backend));
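The command buffers are committed in a pipelined fashion in the compute hunks above: with an abort callback installed only the first two are enqueued up front, and the remaining ones are only committed as long as the callback does not request a stop, so a long computation can be cut short between buffers. An example of the kind of callback a caller might install through ggml_backend_metal_set_abort_callback; this is a sketch, and the deadline policy is made up for illustration:

    #include <stdbool.h>
    #include <stdint.h>
    #include "ggml.h"
    #include "ggml-metal.h"

    struct compute_deadline {
        int64_t t_end_us; // wall-clock deadline in microseconds
    };

    static bool abort_after_deadline(void * user_data) {
        const struct compute_deadline * dl = (const struct compute_deadline *) user_data;
        // returning true makes the backend stop committing the remaining command buffers
        return ggml_time_us() > dl->t_end_us;
    }

    // usage (hypothetical 5 second budget):
    //   struct compute_deadline dl = { ggml_time_us() + 5*1000*1000 };
    //   ggml_backend_metal_set_abort_callback(backend, abort_after_deadline, &dl);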

View file

@@ -17023,12 +17023,6 @@ static void llama_graph_compute(
ggml_cgraph * gf,
int n_threads,
ggml_threadpool * threadpool) {
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(lctx.backend_metal)) {
ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
}
#endif
if (lctx.backend_cpu != nullptr) {
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);