metal : code style changes

This commit is contained in:
Georgi Gerganov 2023-07-25 14:59:17 +03:00
parent ea02f675f9
commit 141d88d916
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -451,13 +451,13 @@ void ggml_metal_graph_compute(
// if there is ctx->concur_list, dispatch concurrently // if there is ctx->concur_list, dispatch concurrently
// else fallback to serial dispatch // else fallback to serial dispatch
MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
encoder_descriptor.dispatchType = MTLDispatchTypeSerial;
int all_nodes_len = gf->n_nodes; const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
if (ctx->concur_list_len) {
encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
all_nodes_len = ctx->concur_list_len; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
}
// create multiple command buffers and enqueue them // create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel // then, we encode the graph into the command buffers in parallel
@ -476,7 +476,7 @@ void ggml_metal_graph_compute(
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb; const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
dispatch_async(queue, ^{ dispatch_async(queue, ^{
size_t offs_src0 = 0; size_t offs_src0 = 0;
@ -487,22 +487,21 @@ void ggml_metal_graph_compute(
id<MTLComputeCommandEncoder> encoder = nil; id<MTLComputeCommandEncoder> encoder = nil;
const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb; const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
for (int ind = node_start; ind < node_end; ++ind) { for (int ind = node_start; ind < node_end; ++ind) {
int i = ind; const int i = has_concur ? ctx->concur_list[ind] : ind;
if (ctx->concur_list_len) {
i = ctx->concur_list[ind];
}
if (i == -1) { if (i == -1) {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
continue; continue;
} }
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue; continue;
} }
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@ -573,7 +572,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
@ -594,7 +593,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL: case GGML_OP_MUL:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
@ -615,7 +614,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
@ -634,7 +633,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
[encoder setComputePipelineState:ctx->pipeline_silu]; [encoder setComputePipelineState:ctx->pipeline_silu];
@ -648,7 +647,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
[encoder setComputePipelineState:ctx->pipeline_relu]; [encoder setComputePipelineState:ctx->pipeline_relu];
@ -662,7 +661,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
[encoder setComputePipelineState:ctx->pipeline_gelu]; [encoder setComputePipelineState:ctx->pipeline_gelu];
@ -682,7 +681,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const int nth = 32; const int nth = 32;
@ -700,7 +699,7 @@ void ggml_metal_graph_compute(
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const int n_past = ((int32_t *)(dst->op_params))[0]; const int n_past = ((int32_t *)(dst->op_params))[0];
@ -763,7 +762,7 @@ void ggml_metal_graph_compute(
} }
} else { } else {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
int nth0 = 32; int nth0 = 32;
@ -890,7 +889,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
switch (src0->type) { switch (src0->type) {
@ -919,7 +918,7 @@ void ggml_metal_graph_compute(
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
float eps; float eps;
@ -942,7 +941,7 @@ void ggml_metal_graph_compute(
case GGML_OP_NORM: case GGML_OP_NORM:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const float eps = 1e-5f; const float eps = 1e-5f;
@ -964,7 +963,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ALIBI: case GGML_OP_ALIBI:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
GGML_ASSERT((src0t == GGML_TYPE_F32)); GGML_ASSERT((src0t == GGML_TYPE_F32));
@ -1007,7 +1006,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
@ -1051,7 +1050,7 @@ void ggml_metal_graph_compute(
case GGML_OP_CONT: case GGML_OP_CONT:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
} }
const int nth = 32; const int nth = 32;