metal : code style changes
This commit is contained in:
parent
ea02f675f9
commit
141d88d916
1 changed files with 29 additions and 30 deletions
57
ggml-metal.m
57
ggml-metal.m
|
@ -451,13 +451,13 @@ void ggml_metal_graph_compute(
|
|||
|
||||
// if there is ctx->concur_list, dispatch concurrently
|
||||
// else fallback to serial dispatch
|
||||
MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor;
|
||||
encoder_descriptor.dispatchType = MTLDispatchTypeSerial;
|
||||
int all_nodes_len = gf->n_nodes;
|
||||
if (ctx->concur_list_len) {
|
||||
encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent;
|
||||
all_nodes_len = ctx->concur_list_len;
|
||||
}
|
||||
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
||||
|
||||
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
|
||||
|
||||
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
||||
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
||||
|
||||
// create multiple command buffers and enqueue them
|
||||
// then, we encode the graph into the command buffers in parallel
|
||||
|
||||
|
@ -476,7 +476,7 @@ void ggml_metal_graph_compute(
|
|||
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||
const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb;
|
||||
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
||||
|
||||
dispatch_async(queue, ^{
|
||||
size_t offs_src0 = 0;
|
||||
|
@ -488,21 +488,20 @@ void ggml_metal_graph_compute(
|
|||
id<MTLComputeCommandEncoder> encoder = nil;
|
||||
|
||||
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
||||
const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb;
|
||||
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
||||
|
||||
for (int ind = node_start; ind < node_end; ++ind) {
|
||||
int i = ind;
|
||||
if (ctx->concur_list_len) {
|
||||
i = ctx->concur_list[ind];
|
||||
}
|
||||
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
||||
|
||||
if (i == -1) {
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
continue;
|
||||
}
|
||||
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
||||
continue;
|
||||
}
|
||||
|
||||
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||
|
||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||
|
@ -573,7 +572,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_ADD:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
|
@ -594,7 +593,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_MUL:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
|
@ -615,7 +614,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_SCALE:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const float scale = *(const float *) src1->data;
|
||||
|
@ -634,7 +633,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_UNARY_OP_SILU:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||
|
@ -648,7 +647,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_relu];
|
||||
|
@ -662,7 +661,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
||||
|
@ -682,7 +681,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const int nth = 32;
|
||||
|
@ -700,7 +699,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_DIAG_MASK_INF:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||
|
@ -763,7 +762,7 @@ void ggml_metal_graph_compute(
|
|||
}
|
||||
} else {
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
int nth0 = 32;
|
||||
|
@ -890,7 +889,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
switch (src0->type) {
|
||||
|
@ -919,7 +918,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
float eps;
|
||||
|
@ -942,7 +941,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_NORM:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const float eps = 1e-5f;
|
||||
|
@ -964,7 +963,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_ALIBI:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||
|
@ -1007,7 +1006,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_ROPE:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
|
@ -1051,7 +1050,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_CONT:
|
||||
{
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||
}
|
||||
|
||||
const int nth = 32;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue