fix diag_mask to work with non-inplace input
parent b9920e5c3e
commit 3dbd649cf9

1 changed file with 15 additions and 6 deletions

ggml.c (+15 -6)
@@ -6215,7 +6215,9 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
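Note: both impl variants now pack their parameters into a single 2-element GGML_TYPE_I32 tensor instead of the old scalar from ggml_new_i32: element 0 carries n_past, element 1 carries the inplace flag, so the compute kernel can tell whether dst aliases src0. A minimal sketch of that convention (pack_diag_mask_params is a hypothetical helper, not part of this commit):

    // hypothetical helper mirroring the packing done in ggml_diag_mask_*_impl
    static struct ggml_tensor * pack_diag_mask_params(struct ggml_context * ctx, int n_past, bool inplace) {
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
        ((int32_t *) b->data)[0] = n_past;          // element 0: mask offset
        ((int32_t *) b->data)[1] = inplace ? 1 : 0; // element 1: does dst alias src0?
        return b;
    }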
@@ -6254,7 +6256,9 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10636,13 +10640,18 @@ static void ggml_compute_forward_diag_mask_f32(
         const float value) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 1);
+    assert(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int  n_past  =       ((int32_t *) src1->data)[0];
+    const bool inplace = (bool)((int32_t *) src1->data)[1];
+
+    if (!inplace) {
+        ggml_compute_forward_dup_same_cont(params, src0, dst);
+    }
 
     // TODO: handle transposed/permuted matrices
 
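This hunk is the actual fix: the kernel only writes the masked positions, so it previously relied on dst already holding src0's values, which is only true in-place. For the non-inplace case it now copies src0 into dst first via ggml_compute_forward_dup_same_cont. A sketch of the effective semantics on a contiguous nr x nc row-major matrix (diag_mask_ref is a hypothetical reference function, assuming ggml's mask condition i > n_past + j):

    // hypothetical reference: copy, then mask everything above the shifted diagonal
    static void diag_mask_ref(float * dst, const float * src, int nr, int nc, int n_past, float value) {
        for (int j = 0; j < nr; j++) {
            for (int i = 0; i < nc; i++) {
                // copy step (what the new !inplace dup provides), then overwrite
                dst[j*nc + i] = (i > n_past + j) ? value : src[j*nc + i];
            }
        }
    }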
@@ -13288,7 +13297,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             // necessary for llama
             if (src0->grad) {
                 assert(src1->type == GGML_TYPE_I32);
-                assert(ggml_nelements(src1) == 1);
+                assert(ggml_nelements(src1) == 2);
                 const int n_past = ((int32_t *) src1->data)[0];
                 src0->grad =
                     ggml_add_impl(ctx, src0->grad,
@@ -13304,7 +13313,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             // necessary for llama
             if (src0->grad) {
                 assert(src1->type == GGML_TYPE_I32);
-                assert(ggml_nelements(src1) == 1);
+                assert(ggml_nelements(src1) == 2);
                 const int n_past = ((int32_t *) src1->data)[0];
                 src0->grad =
                     ggml_add_impl(ctx, src0->grad,
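With this change the non-inplace entry points are safe to use. A usage sketch, assuming the public wrappers from ggml.h and caller-provided ctx, kq, and n_past:

    // non-inplace: kq keeps its values; the mask lands in a fresh tensor
    struct ggml_tensor * masked = ggml_diag_mask_inf(ctx, kq, n_past);

    // in-place: kq's own buffer is masked (the previously working fast path)
    struct ggml_tensor * masked_inplace = ggml_diag_mask_inf_inplace(ctx, kq, n_past);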