diff --git a/ggml-metal.m b/ggml-metal.m index f9bd69dc8..24d3989bc 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -964,9 +964,9 @@ void ggml_metal_graph_compute( const int64_t nb = ne00; [encoder setComputePipelineState:ctx->pipeline_concat]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; @@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false); } } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; @@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute( [encoder setComputePipelineState:ctx->pipeline_scale]; } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; @@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_SILU: { [encoder setComputePipelineState:ctx->pipeline_silu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); GGML_ASSERT(n % 4 == 0); @@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_RELU: { [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); @@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute( case GGML_UNARY_OP_GELU: { [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); GGML_ASSERT(n % 4 == 0); @@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ggml_is_contiguous(src0)); [encoder setComputePipelineState:ctx->pipeline_sqr]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; @@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); [encoder setComputePipelineState:ctx->pipeline_sum_rows]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; @@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute( const float scale = ((float *) dst->op_params)[0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; @@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute( } else { [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; @@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; @@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute( } }; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; @@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; @@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; @@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute( } [encoder setComputePipelineState:ctx->pipeline_rms_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; @@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute( const int nth = MIN(256, ne00); [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; @@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute( const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; @@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false); }; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; @@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false); }; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; @@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false); }; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; @@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];