Update ggml_diag_mask to work correctly with automatic in-place operation
This commit is contained in:
parent
28c6e324d2
commit
5460aec056
1 changed file with 3 additions and 4 deletions
7
ggml.c
7
ggml.c
|
@@ -6614,7 +6614,7 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl(
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[] = { n_past, inplace ? 1 : 0 };
|
int32_t params[] = { n_past };
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
result->op = GGML_OP_DIAG_MASK_INF;
|
result->op = GGML_OP_DIAG_MASK_INF;
|
||||||
|
@@ -6631,7 +6631,6 @@ struct ggml_tensor * ggml_diag_mask_inf(
|
||||||
return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
|
return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_diag_mask_inf_inplace(
|
struct ggml_tensor * ggml_diag_mask_inf_inplace(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
@@ -6654,7 +6653,7 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl(
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[] = { n_past, inplace ? 1 : 0 };
|
int32_t params[] = { n_past };
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
result->op = GGML_OP_DIAG_MASK_ZERO;
|
result->op = GGML_OP_DIAG_MASK_ZERO;
|
||||||
|
@@ -11910,7 +11909,7 @@ static void ggml_compute_forward_diag_mask_f32(
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const bool inplace = (bool)((int32_t *) dst->op_params)[1];
|
const bool inplace = src0->data == dst->data;
|
||||||
|
|
||||||
GGML_ASSERT(n_past >= 0);
|
GGML_ASSERT(n_past >= 0);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue