diff --git a/ggml.c b/ggml.c index 1b792aa22..ade83d5ac 100644 --- a/ggml.c +++ b/ggml.c @@ -6614,7 +6614,7 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[] = { n_past, inplace ? 1 : 0 }; + int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_DIAG_MASK_INF; @@ -6631,7 +6631,6 @@ struct ggml_tensor * ggml_diag_mask_inf( return ggml_diag_mask_inf_impl(ctx, a, n_past, false); } - struct ggml_tensor * ggml_diag_mask_inf_inplace( struct ggml_context * ctx, struct ggml_tensor * a, @@ -6654,7 +6653,7 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[] = { n_past, inplace ? 1 : 0 }; + int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_DIAG_MASK_ZERO; @@ -11910,7 +11909,7 @@ static void ggml_compute_forward_diag_mask_f32( const int nth = params->nth; const int n_past = ((int32_t *) dst->op_params)[0]; - const bool inplace = (bool)((int32_t *) dst->op_params)[1]; + const bool inplace = src0->data == dst->data; GGML_ASSERT(n_past >= 0);