fix diag_mask to work with non-inplace input

This commit is contained in:
xaedes 2023-04-28 20:03:56 +02:00
parent b9920e5c3e
commit 3dbd649cf9
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

21
ggml.c
View file

@ -6215,7 +6215,9 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = inplace ? 1 : 0;
result->op = GGML_OP_DIAG_MASK_INF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -6254,7 +6256,9 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = inplace ? 1 : 0;
result->op = GGML_OP_DIAG_MASK_ZERO;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -10636,13 +10640,18 @@ static void ggml_compute_forward_diag_mask_f32(
const float value) {
assert(params->ith == 0);
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 1);
assert(ggml_nelements(src1) == 2);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const int n_past = ((int32_t *) src1->data)[0];
const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];
if (!inplace) {
ggml_compute_forward_dup_same_cont(params, src0, dst);
}
// TODO: handle transposed/permuted matrices
@ -13288,7 +13297,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 1);
assert(ggml_nelements(src1) == 2);
const int n_past = ((int32_t *) src1->data)[0];
src0->grad =
ggml_add_impl(ctx, src0->grad,
@ -13304,7 +13313,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 1);
assert(ggml_nelements(src1) == 2);
const int n_past = ((int32_t *) src1->data)[0];
src0->grad =
ggml_add_impl(ctx, src0->grad,