fix diag_mask to work with non-inplace input
This commit is contained in:
parent
b9920e5c3e
commit
3dbd649cf9
1 changed file with 15 additions and 6 deletions
19
ggml.c
19
ggml.c
|
@ -6215,7 +6215,9 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
|
||||||
|
((int32_t *) b->data)[0] = n_past;
|
||||||
|
((int32_t *) b->data)[1] = inplace ? 1 : 0;
|
||||||
|
|
||||||
result->op = GGML_OP_DIAG_MASK_INF;
|
result->op = GGML_OP_DIAG_MASK_INF;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
@ -6254,7 +6256,9 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
|
||||||
|
((int32_t *) b->data)[0] = n_past;
|
||||||
|
((int32_t *) b->data)[1] = inplace ? 1 : 0;
|
||||||
|
|
||||||
result->op = GGML_OP_DIAG_MASK_ZERO;
|
result->op = GGML_OP_DIAG_MASK_ZERO;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
@ -10636,13 +10640,18 @@ static void ggml_compute_forward_diag_mask_f32(
|
||||||
const float value) {
|
const float value) {
|
||||||
assert(params->ith == 0);
|
assert(params->ith == 0);
|
||||||
assert(src1->type == GGML_TYPE_I32);
|
assert(src1->type == GGML_TYPE_I32);
|
||||||
assert(ggml_nelements(src1) == 1);
|
assert(ggml_nelements(src1) == 2);
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_past = ((int32_t *) src1->data)[0];
|
const int n_past = ((int32_t *) src1->data)[0];
|
||||||
|
const bool inplace = (bool)((int32_t *) src1->data)[1];
|
||||||
|
|
||||||
|
if (!inplace) {
|
||||||
|
ggml_compute_forward_dup_same_cont(params, src0, dst);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: handle transposed/permuted matrices
|
// TODO: handle transposed/permuted matrices
|
||||||
|
|
||||||
|
@ -13288,7 +13297,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
assert(src1->type == GGML_TYPE_I32);
|
assert(src1->type == GGML_TYPE_I32);
|
||||||
assert(ggml_nelements(src1) == 1);
|
assert(ggml_nelements(src1) == 2);
|
||||||
const int n_past = ((int32_t *) src1->data)[0];
|
const int n_past = ((int32_t *) src1->data)[0];
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_impl(ctx, src0->grad,
|
ggml_add_impl(ctx, src0->grad,
|
||||||
|
@ -13304,7 +13313,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
assert(src1->type == GGML_TYPE_I32);
|
assert(src1->type == GGML_TYPE_I32);
|
||||||
assert(ggml_nelements(src1) == 1);
|
assert(ggml_nelements(src1) == 2);
|
||||||
const int n_past = ((int32_t *) src1->data)[0];
|
const int n_past = ((int32_t *) src1->data)[0];
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_impl(ctx, src0->grad,
|
ggml_add_impl(ctx, src0->grad,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue