diff --git a/ggml-metal.m b/ggml-metal.m index e73fb2517..74a6bff40 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -451,13 +451,13 @@ void ggml_metal_graph_compute( // if there is ctx->concur_list, dispatch concurrently // else fallback to serial dispatch - MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor; - encoder_descriptor.dispatchType = MTLDispatchTypeSerial; - int all_nodes_len = gf->n_nodes; - if (ctx->concur_list_len) { - encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent; - all_nodes_len = ctx->concur_list_len; - } + MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; + + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; + + const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; + edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; + // create multiple command buffers and enqueue them // then, we encode the graph into the command buffers in parallel @@ -476,7 +476,7 @@ void ggml_metal_graph_compute( dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb; + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; dispatch_async(queue, ^{ size_t offs_src0 = 0; @@ -487,22 +487,21 @@ void ggml_metal_graph_compute( id encoder = nil; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; for (int ind = node_start; ind < node_end; ++ind) { - int i = ind; - if (ctx->concur_list_len) { - i = ctx->concur_list[ind]; - } + const int i = has_concur ? ctx->concur_list[ind] : ind; + if (i == -1) { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; continue; } [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; continue; } + metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); struct ggml_tensor * src0 = gf->nodes[i]->src[0]; @@ -573,7 +572,7 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } if (ggml_nelements(src1) == ne10) { @@ -594,7 +593,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } if (ggml_nelements(src1) == ne10) { @@ -615,7 +614,7 @@ void ggml_metal_graph_compute( case GGML_OP_SCALE: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float scale = *(const float *) src1->data; @@ -634,7 +633,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_SILU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_silu]; @@ -648,7 +647,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_RELU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_relu]; @@ -662,7 +661,7 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_GELU: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } [encoder setComputePipelineState:ctx->pipeline_gelu]; @@ -682,7 +681,7 @@ void ggml_metal_graph_compute( case GGML_OP_SOFT_MAX: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32; @@ -700,7 +699,7 @@ void ggml_metal_graph_compute( case GGML_OP_DIAG_MASK_INF: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int n_past = ((int32_t *)(dst->op_params))[0]; @@ -763,7 +762,7 @@ void ggml_metal_graph_compute( } } else { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } int nth0 = 32; @@ -890,7 +889,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } switch (src0->type) { @@ -919,7 +918,7 @@ void ggml_metal_graph_compute( case GGML_OP_RMS_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } float eps; @@ -942,7 +941,7 @@ void ggml_metal_graph_compute( case GGML_OP_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float eps = 1e-5f; @@ -964,7 +963,7 @@ void ggml_metal_graph_compute( case GGML_OP_ALIBI: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } GGML_ASSERT((src0t == GGML_TYPE_F32)); @@ -1007,7 +1006,7 @@ void ggml_metal_graph_compute( case GGML_OP_ROPE: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int n_past = ((int32_t *) dst->op_params)[0]; @@ -1051,7 +1050,7 @@ void ggml_metal_graph_compute( case GGML_OP_CONT: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32;