metal: issue operations concurrently if possible

Using the new ggml functions.
This commit is contained in:
lshzh-ww 2023-07-21 11:23:51 -04:00
parent 1c3030ee41
commit 6ee897a501
2 changed files with 21 additions and 16 deletions

View file

@@ -364,7 +364,7 @@ void ggml_metal_graph_compute(
     size_t offs_src1 = 0;
     size_t offs_dst = 0;
-    id<MTLComputeCommandEncoder> encoder = nil;
+    id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
     for (int i = 0; i < gf->n_nodes; ++i) {
         metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@@ -434,10 +434,14 @@ void ggml_metal_graph_compute(
                 {
                     // noop
                 } break;
+            case GGML_OP_BARRIER:
+                {
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers | MTLBarrierScopeRenderTargets | MTLBarrierScopeTextures];
+                } break;
             case GGML_OP_ADD:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     [encoder setComputePipelineState:ctx->pipeline_add];
@@ -452,7 +456,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_MUL:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     if (ggml_nelements(src1) == ne10) {
@@ -473,7 +477,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_SCALE:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const float scale = *(const float *) src1->data;
@@ -490,7 +494,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_SILU:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                    }
                     [encoder setComputePipelineState:ctx->pipeline_silu];
@@ -504,7 +508,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_RELU:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     [encoder setComputePipelineState:ctx->pipeline_relu];
@@ -518,7 +522,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_GELU:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     [encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -532,7 +536,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_SOFT_MAX:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const int nth = 32;
@@ -550,7 +554,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_DIAG_MASK_INF:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const int n_past = ((int32_t *)(src1->data))[0];
@@ -613,7 +617,7 @@ void ggml_metal_graph_compute(
                     }
                 } else {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     int nth0 = 32;
@@ -740,7 +744,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_GET_ROWS:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     switch (src0->type) {
@@ -769,7 +773,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const float eps = 1e-6f;
@@ -791,7 +795,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_NORM:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const float eps = 1e-5f;
@@ -813,7 +817,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_ALIBI:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -855,7 +859,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_ROPE:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const int n_dims = ((int32_t *) src1->data)[1];
@@ -898,7 +902,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_CPY:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
                     }
                     const int nth = 32;

View file

@@ -1662,6 +1662,7 @@ static bool llama_eval_internal(
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        ggml_graph_find_concurrency(ctx0,&gf);
         ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor (lctx.ctx_metal, cur);