Fix Metal API validation errors

This commit is contained in:
Finn Voorhees 2023-12-08 11:44:08 +00:00 committed by GitHub
parent fe680e3d10
commit bc38194ef4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
const int64_t nb = ne00;
[encoder setComputePipelineState:ctx->pipeline_concat];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
}
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_scale];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_SILU:
{
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_RELU:
{
[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_GELU:
{
[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));
[encoder setComputePipelineState:ctx->pipeline_sqr];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
} else {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
}
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
}
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
const int nth = MIN(256, ne00);
[encoder setComputePipelineState:ctx->pipeline_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];