metal : code style changes

This commit is contained in:
Georgi Gerganov 2023-07-25 14:59:17 +03:00
parent ea02f675f9
commit 141d88d916
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -451,13 +451,13 @@ void ggml_metal_graph_compute(
// if there is ctx->concur_list, dispatch concurrently
// else fallback to serial dispatch
MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor;
encoder_descriptor.dispatchType = MTLDispatchTypeSerial;
int all_nodes_len = gf->n_nodes;
if (ctx->concur_list_len) {
encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent;
all_nodes_len = ctx->concur_list_len;
}
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
@ -476,7 +476,7 @@ void ggml_metal_graph_compute(
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
dispatch_async(queue, ^{
size_t offs_src0 = 0;
@ -487,22 +487,21 @@ void ggml_metal_graph_compute(
id<MTLComputeCommandEncoder> encoder = nil;
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb;
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
for (int ind = node_start; ind < node_end; ++ind) {
int i = ind;
if (ctx->concur_list_len) {
i = ctx->concur_list[ind];
}
const int i = has_concur ? ctx->concur_list[ind] : ind;
if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@ -573,7 +572,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
@ -594,7 +593,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
@ -615,7 +614,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float scale = *(const float *) src1->data;
@ -634,7 +633,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_silu];
@ -648,7 +647,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_relu];
@ -662,7 +661,7 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_gelu];
@ -682,7 +681,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
@ -700,7 +699,7 @@ void ggml_metal_graph_compute(
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *)(dst->op_params))[0];
@ -763,7 +762,7 @@ void ggml_metal_graph_compute(
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
int nth0 = 32;
@ -890,7 +889,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
switch (src0->type) {
@ -919,7 +918,7 @@ void ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
float eps;
@ -942,7 +941,7 @@ void ggml_metal_graph_compute(
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float eps = 1e-5f;
@ -964,7 +963,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
GGML_ASSERT((src0t == GGML_TYPE_F32));
@ -1007,7 +1006,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *) dst->op_params)[0];
@ -1051,7 +1050,7 @@ void ggml_metal_graph_compute(
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;