metal : increase concurrency nodes to 2*GGML_MAX_NODES

This commit is contained in:
Georgi Gerganov 2023-08-07 10:52:13 +03:00
parent 1038d1d2bc
commit 30ea0e1685
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -20,6 +20,8 @@
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
struct ggml_metal_buffer { struct ggml_metal_buffer {
const char * name; const char * name;
@ -41,7 +43,7 @@ struct ggml_metal_context {
int n_buffers; int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
int concur_list[GGML_MAX_NODES]; int concur_list[GGML_MAX_CONCUR];
int concur_list_len; int concur_list_len;
// custom kernels // custom kernels
@ -375,10 +377,10 @@ void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_NODES]; int nodes_unused[GGML_MAX_CONCUR];
for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; } for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
ctx->concur_list_len = 0; ctx->concur_list_len = 0;
int n_left = gf->n_nodes; int n_left = gf->n_nodes;
@ -458,7 +460,7 @@ void ggml_metal_graph_find_concurrency(
level_pos += concurrency + 1; level_pos += concurrency + 1;
} }
if (ctx->concur_list_len > GGML_MAX_NODES) { if (ctx->concur_list_len > GGML_MAX_CONCUR) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
} }
} }
@ -472,7 +474,7 @@ void ggml_metal_graph_compute(
// else fallback to serial dispatch // else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;