diff --git a/ggml-metal.m b/ggml-metal.m index 2cc6dfab2..cadbfe04a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -838,8 +838,6 @@ void ggml_metal_graph_compute( case GGML_OP_DIAG_MASK_INF: { const int n_past = ((int32_t *)(dst->op_params))[0]; - const int64_t n00x01 = ne00*ne01; - assert((n00x01*ne02)%8 == 0); if (ne00%8 == 0) { [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; @@ -848,8 +846,8 @@ void ggml_metal_graph_compute( } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&n00x01 length:sizeof(n00x01) atIndex:3]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; if (ne00%8 == 0) { diff --git a/ggml-metal.metal b/ggml-metal.metal index cf8f2e610..ff289908a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -183,13 +183,15 @@ kernel void kernel_soft_max_4( } kernel void kernel_diag_mask_inf( - device const float4 * src0, - device float4 * dst, + device const float * src0, + device float * dst, constant int64_t & ne00, - constant int64_t & n00x01, + constant int64_t & ne01, constant int & n_past, uint3 tpig[[thread_position_in_grid]]) { - const int64_t i = 2*tpig[0]; + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; if (i00 > n_past + i01) { dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;