fix soft_max backward pass for input->ne[1] != 1

This commit is contained in:
xaedes 2023-05-06 17:30:38 +02:00
parent b4c273f7a3
commit 8cf04fec9d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 38 additions and 22 deletions

59
ggml.c
View file

@ -10721,7 +10721,7 @@ static void ggml_compute_forward_diag_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
assert(params->ith == 0);
GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@ -10737,11 +10737,11 @@ static void ggml_compute_forward_diag_f32(
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
assert(ne00 == ne0);
assert(ne00 == ne1);
assert(ne01 == 1);
assert(ne02 == ne2);
assert(ne03 == ne3);
GGML_ASSERT(ne00 == ne0);
GGML_ASSERT(ne00 == ne1);
GGML_ASSERT(ne01 == 1);
GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne3);
const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1];
@ -10752,8 +10752,8 @@ static void ggml_compute_forward_diag_f32(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
assert(nb00 == sizeof(float));
assert(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = 0; i2 < ne2; i2++) {
@ -13545,23 +13545,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// dx = J * dy
// dxk = sum(Jkj * dyk)
struct ggml_tensor * tensor_t = ggml_cont(ctx,
ggml_permute(ctx,
ggml_reshape(ctx,
tensor,
ggml_new_tensor(ctx,
tensor->type,
4, tensor->ne)),
int64_t ne2[4] = {
tensor->ne[0],
1,
tensor->ne[1]*tensor->ne[2],
tensor->ne[3]
};
struct ggml_tensor * tensor2 = ggml_cont(ctx,
ggml_reshape_4d(ctx,
ggml_cont(ctx, tensor),
ne2[0], ne2[1], ne2[2], ne2[3]));
struct ggml_tensor * grad2 = ggml_cont(ctx,
ggml_reshape_4d(ctx,
ggml_cont(ctx, tensor->grad),
ne2[0], ne2[1], ne2[2], ne2[3]));
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
tensor2, // [ne0,1,ne1*ne2,ne3]
1, 0, 2, 3));
src0->grad =
ggml_add_impl(ctx,
src0->grad,
ggml_mul_mat(ctx,
ggml_sub(ctx,
ggml_diag(ctx, tensor),
ggml_mul_mat(ctx, tensor_t, tensor_t)),
tensor->grad),
src0->grad, // [ne0,ne1,ne2,ne3]
ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
tensor2), // [ne0,1,ne1*ne2,ne3]
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
tensor2_t, // [1,ne0,ne1*ne2,ne3]
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
grad2), // [ne0,1,ne1*ne2,ne3]
src0->grad),
inplace);
}
} break;

View file

@ -859,7 +859,6 @@ int main(int argc, const char ** argv) {
int64_t ne2[4];
get_random_dims(ne2, 4);
ne2[1] = 1;
for (int ndims = 1; ndims <= 3; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);