metal : increase concurrency nodes to 2*GGML_MAX_NODES

This commit is contained in:
Georgi Gerganov 2023-08-07 10:52:13 +03:00
parent 1038d1d2bc
commit 30ea0e1685
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -20,6 +20,8 @@
#define UNUSED(x) (void)(x)
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
struct ggml_metal_buffer {
const char * name;
@ -41,7 +43,7 @@ struct ggml_metal_context {
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
int concur_list[GGML_MAX_NODES];
int concur_list[GGML_MAX_CONCUR];
int concur_list_len;
// custom kernels
@ -375,10 +377,10 @@ void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_NODES];
int nodes_unused[GGML_MAX_CONCUR];
for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
ctx->concur_list_len = 0;
int n_left = gf->n_nodes;
@ -458,7 +460,7 @@ void ggml_metal_graph_find_concurrency(
level_pos += concurrency + 1;
}
if (ctx->concur_list_len > GGML_MAX_NODES) {
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
}
}
@ -472,7 +474,7 @@ void ggml_metal_graph_compute(
// else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;