diff --git a/ggml.c b/ggml.c
index 515a4d19c..985ea7317 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6215,7 +6215,9 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6254,7 +6256,9 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10636,13 +10640,18 @@ static void ggml_compute_forward_diag_mask_f32(
         const float value) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 1);
+    assert(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int  n_past  = ((int32_t *) src1->data)[0];
+    const bool inplace = (bool)((int32_t *) src1->data)[1];
+
+    if (!inplace) {
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
+    }
 
     // TODO: handle transposed/permuted matrices
@@ -13288,7 +13297,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 1);
+                    assert(ggml_nelements(src1) == 2);
                     const int n_past = ((int32_t *) src1->data)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
@@ -13304,7 +13313,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 1);
+                    assert(ggml_nelements(src1) == 2);
                     const int n_past = ((int32_t *) src1->data)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
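
Note on the change: `GGML_OP_DIAG_MASK_INF` and `GGML_OP_DIAG_MASK_ZERO` previously received only `n_past` through `src1`. The parameter tensor now carries two `int32_t` values, `n_past` and an `inplace` flag, so that the forward kernel can tell whether `dst` aliases `src0`; when the op is not in-place, the kernel first copies `src0` into `dst` via `ggml_compute_forward_dup_same_cont` before applying the mask. Below is a minimal sketch of the packing convention, using the real `ggml_new_tensor_1d`/`GGML_TYPE_I32` API from `ggml.h`; the helper name `diag_mask_params` is hypothetical and only mirrors what the two `_impl` functions in this diff do inline:

```c
#include <stdbool.h>
#include <stdint.h>

#include "ggml.h"

// Hypothetical helper (illustration only): packs the two op parameters
// exactly as ggml_diag_mask_inf_impl / ggml_diag_mask_zero_impl now do.
static struct ggml_tensor * diag_mask_params(
        struct ggml_context * ctx, int n_past, bool inplace) {
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
    ((int32_t *) b->data)[0] = n_past;          // mask offset: entries past n_past + row get masked
    ((int32_t *) b->data)[1] = inplace ? 1 : 0; // 0 => forward pass must dup src0 into dst first
    return b;
}
```

The flag has to travel through the parameter tensor because `ggml_compute_forward_diag_mask_f32` only sees `src0`, `src1`, and `dst` at compute time; whether `result` was built with `ggml_view_tensor` (aliasing) or `ggml_dup_tensor` (separate storage) is otherwise no longer visible there.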