fix soft_max backward pass for input->ne[1] != 1

This commit is contained in:
xaedes 2023-05-06 17:30:38 +02:00
parent b4c273f7a3
commit 8cf04fec9d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 38 additions and 22 deletions

59
ggml.c
View file

@ -10721,7 +10721,7 @@ static void ggml_compute_forward_diag_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
assert(params->ith == 0); GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
@ -10737,11 +10737,11 @@ static void ggml_compute_forward_diag_f32(
const int ne1 = dst->ne[1]; const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2]; const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3]; const int ne3 = dst->ne[3];
assert(ne00 == ne0); GGML_ASSERT(ne00 == ne0);
assert(ne00 == ne1); GGML_ASSERT(ne00 == ne1);
assert(ne01 == 1); GGML_ASSERT(ne01 == 1);
assert(ne02 == ne2); GGML_ASSERT(ne02 == ne2);
assert(ne03 == ne3); GGML_ASSERT(ne03 == ne3);
const int nb00 = src0->nb[0]; const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1]; const int nb01 = src0->nb[1];
@ -10752,8 +10752,8 @@ static void ggml_compute_forward_diag_f32(
const int nb2 = dst->nb[2]; const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3]; const int nb3 = dst->nb[3];
assert(nb00 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float));
assert(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
for (int i3 = 0; i3 < ne3; i3++) { for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = 0; i2 < ne2; i2++) { for (int i2 = 0; i2 < ne2; i2++) {
@ -13545,23 +13545,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// dx = J * dy // dx = J * dy
// dxk = sum(Jkj * dyk) // dxk = sum(Jkj * dyk)
struct ggml_tensor * tensor_t = ggml_cont(ctx, int64_t ne2[4] = {
ggml_permute(ctx, tensor->ne[0],
ggml_reshape(ctx, 1,
tensor, tensor->ne[1]*tensor->ne[2],
ggml_new_tensor(ctx, tensor->ne[3]
tensor->type, };
4, tensor->ne)), struct ggml_tensor * tensor2 = ggml_cont(ctx,
ggml_reshape_4d(ctx,
ggml_cont(ctx, tensor),
ne2[0], ne2[1], ne2[2], ne2[3]));
struct ggml_tensor * grad2 = ggml_cont(ctx,
ggml_reshape_4d(ctx,
ggml_cont(ctx, tensor->grad),
ne2[0], ne2[1], ne2[2], ne2[3]));
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
tensor2, // [ne0,1,ne1*ne2,ne3]
1, 0, 2, 3)); 1, 0, 2, 3));
src0->grad = src0->grad =
ggml_add_impl(ctx, ggml_add_impl(ctx,
src0->grad, src0->grad, // [ne0,ne1,ne2,ne3]
ggml_mul_mat(ctx, ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
ggml_sub(ctx, ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
ggml_diag(ctx, tensor), ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
ggml_mul_mat(ctx, tensor_t, tensor_t)), ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
tensor->grad), tensor2), // [ne0,1,ne1*ne2,ne3]
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
tensor2_t, // [1,ne0,ne1*ne2,ne3]
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
grad2), // [ne0,1,ne1*ne2,ne3]
src0->grad),
inplace); inplace);
} }
} break; } break;

View file

@ -859,7 +859,6 @@ int main(int argc, const char ** argv) {
int64_t ne2[4]; int64_t ne2[4];
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
ne2[1] = 1;
for (int ndims = 1; ndims <= 3; ++ndims) { for (int ndims = 1; ndims <= 3; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);