Fix Metal API validation errors
This commit is contained in:
parent
fe680e3d10
commit
bc38194ef4
1 changed files with 50 additions and 50 deletions
100
ggml-metal.m
100
ggml-metal.m
|
@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
|
||||||
const int64_t nb = ne00;
|
const int64_t nb = ne00;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_concat];
|
[encoder setComputePipelineState:ctx->pipeline_concat];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
|
@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
|
@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
GGML_ASSERT(n % 4 == 0);
|
GGML_ASSERT(n % 4 == 0);
|
||||||
|
@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->pipeline_relu];
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
{
|
{
|
||||||
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
GGML_ASSERT(n % 4 == 0);
|
GGML_ASSERT(n % 4 == 0);
|
||||||
|
@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_sqr];
|
[encoder setComputePipelineState:ctx->pipeline_sqr];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
|
||||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
|
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
|
@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
const float scale = ((float *) dst->op_params)[0];
|
const float scale = ((float *) dst->op_params)[0];
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
|
@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||||
|
@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
||||||
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||||
|
@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
|
@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
||||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
||||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
||||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
||||||
|
@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||||
|
@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||||
|
@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
|
||||||
const int nth = MIN(256, ne00);
|
const int nth = MIN(256, ne00);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_norm];
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||||
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
||||||
|
@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||||
|
@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
||||||
|
@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||||
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||||
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||||
|
@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
||||||
|
@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue