improve performance of sqr backward pass

use scale(x,y) instead of mul(x,repeat(y,x))
This commit is contained in:
xaedes 2023-04-26 00:46:20 +02:00
parent bfe507213c
commit b583136cfa
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

8
ggml.c
View file

@@ -12616,9 +12616,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 src0->grad =
 ggml_add_impl(ctx,
 src0->grad,
-ggml_mul(ctx,
+ggml_scale(ctx,
 ggml_mul(ctx, src0, tensor->grad),
-ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+ggml_new_f32(ctx, 2.0f)),
 inplace);
 }
 } break;
@@ -12965,7 +12965,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 const int n_past = ((int32_t *) src1->data)[0];
 const int n_dims = ((int32_t *) src1->data)[1];
 const int mode = ((int32_t *) src1->data)[2];
-src0->grad = ggml_sub_impl(ctx,
+src0->grad = ggml_add_impl(ctx,
 src0->grad,
 ggml_rope_back(ctx,
 tensor->grad,
@@ -12986,7 +12986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 const int n_past = ((int32_t *) src1->data)[0];
 const int n_dims = ((int32_t *) src1->data)[1];
 const int mode = ((int32_t *) src1->data)[2];
-src0->grad = ggml_sub_impl(ctx,
+src0->grad = ggml_add_impl(ctx,
 src0->grad,
 ggml_rope(ctx,
 tensor->grad,