metal : code style changes
This commit is contained in:
parent
ea02f675f9
commit
141d88d916
1 changed file with 29 additions and 30 deletions
57
ggml-metal.m
57
ggml-metal.m
|
@ -451,13 +451,13 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
// if there is ctx->concur_list, dispatch concurrently
|
// if there is ctx->concur_list, dispatch concurrently
|
||||||
// else fallback to serial dispatch
|
// else fallback to serial dispatch
|
||||||
MTLComputePassDescriptor * encoder_descriptor = MTLComputePassDescriptor.computePassDescriptor;
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
||||||
encoder_descriptor.dispatchType = MTLDispatchTypeSerial;
|
|
||||||
int all_nodes_len = gf->n_nodes;
|
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
|
||||||
if (ctx->concur_list_len) {
|
|
||||||
encoder_descriptor.dispatchType = MTLDispatchTypeConcurrent;
|
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
||||||
all_nodes_len = ctx->concur_list_len;
|
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
||||||
}
|
|
||||||
// create multiple command buffers and enqueue them
|
// create multiple command buffers and enqueue them
|
||||||
// then, we encode the graph into the command buffers in parallel
|
// then, we encode the graph into the command buffers in parallel
|
||||||
|
|
||||||
|
@ -476,7 +476,7 @@ void ggml_metal_graph_compute(
|
||||||
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
||||||
|
|
||||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||||
const int n_nodes_per_cb = (all_nodes_len + n_cb - 1) / n_cb;
|
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
||||||
|
|
||||||
dispatch_async(queue, ^{
|
dispatch_async(queue, ^{
|
||||||
size_t offs_src0 = 0;
|
size_t offs_src0 = 0;
|
||||||
|
@ -488,21 +488,20 @@ void ggml_metal_graph_compute(
|
||||||
id<MTLComputeCommandEncoder> encoder = nil;
|
id<MTLComputeCommandEncoder> encoder = nil;
|
||||||
|
|
||||||
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
||||||
const int node_end = (cb_idx == n_cb - 1) ? all_nodes_len : (cb_idx + 1) * n_nodes_per_cb;
|
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
||||||
|
|
||||||
for (int ind = node_start; ind < node_end; ++ind) {
|
for (int ind = node_start; ind < node_end; ++ind) {
|
||||||
int i = ind;
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
||||||
if (ctx->concur_list_len) {
|
|
||||||
i = ctx->concur_list[ind];
|
|
||||||
}
|
|
||||||
if (i == -1) {
|
if (i == -1) {
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||||
|
|
||||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||||
|
@ -573,7 +572,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
|
@ -594,7 +593,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
|
@ -615,7 +614,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const float scale = *(const float *) src1->data;
|
const float scale = *(const float *) src1->data;
|
||||||
|
@ -634,7 +633,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||||
|
@ -648,7 +647,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_relu];
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
||||||
|
@ -662,7 +661,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
||||||
|
@ -682,7 +681,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
@ -700,7 +699,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
@ -763,7 +762,7 @@ void ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
|
@ -890,7 +889,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
@ -919,7 +918,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
|
@ -942,7 +941,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const float eps = 1e-5f;
|
const float eps = 1e-5f;
|
||||||
|
@ -964,7 +963,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||||
|
@ -1007,7 +1006,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
|
@ -1051,7 +1050,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
|
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue