fix soft_max backward pass for input->ne[1] != 1
This commit is contained in:
parent
b4c273f7a3
commit
8cf04fec9d
2 changed files with 38 additions and 22 deletions
59
ggml.c
59
ggml.c
|
@ -10721,7 +10721,7 @@ static void ggml_compute_forward_diag_f32(
|
|||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(params->ith == 0);
|
||||
GGML_ASSERT(params->ith == 0);
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
|
@ -10737,11 +10737,11 @@ static void ggml_compute_forward_diag_f32(
|
|||
const int ne1 = dst->ne[1];
|
||||
const int ne2 = dst->ne[2];
|
||||
const int ne3 = dst->ne[3];
|
||||
assert(ne00 == ne0);
|
||||
assert(ne00 == ne1);
|
||||
assert(ne01 == 1);
|
||||
assert(ne02 == ne2);
|
||||
assert(ne03 == ne3);
|
||||
GGML_ASSERT(ne00 == ne0);
|
||||
GGML_ASSERT(ne00 == ne1);
|
||||
GGML_ASSERT(ne01 == 1);
|
||||
GGML_ASSERT(ne02 == ne2);
|
||||
GGML_ASSERT(ne03 == ne3);
|
||||
|
||||
const int nb00 = src0->nb[0];
|
||||
const int nb01 = src0->nb[1];
|
||||
|
@ -10752,8 +10752,8 @@ static void ggml_compute_forward_diag_f32(
|
|||
const int nb2 = dst->nb[2];
|
||||
const int nb3 = dst->nb[3];
|
||||
|
||||
assert(nb00 == sizeof(float));
|
||||
assert(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = 0; i2 < ne2; i2++) {
|
||||
|
@ -13545,23 +13545,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// dx = J * dy
|
||||
// dxk = sum(Jkj * dyk)
|
||||
|
||||
struct ggml_tensor * tensor_t = ggml_cont(ctx,
|
||||
ggml_permute(ctx,
|
||||
ggml_reshape(ctx,
|
||||
tensor,
|
||||
ggml_new_tensor(ctx,
|
||||
tensor->type,
|
||||
4, tensor->ne)),
|
||||
int64_t ne2[4] = {
|
||||
tensor->ne[0],
|
||||
1,
|
||||
tensor->ne[1]*tensor->ne[2],
|
||||
tensor->ne[3]
|
||||
};
|
||||
struct ggml_tensor * tensor2 = ggml_cont(ctx,
|
||||
ggml_reshape_4d(ctx,
|
||||
ggml_cont(ctx, tensor),
|
||||
ne2[0], ne2[1], ne2[2], ne2[3]));
|
||||
|
||||
struct ggml_tensor * grad2 = ggml_cont(ctx,
|
||||
ggml_reshape_4d(ctx,
|
||||
ggml_cont(ctx, tensor->grad),
|
||||
ne2[0], ne2[1], ne2[2], ne2[3]));
|
||||
|
||||
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
|
||||
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
|
||||
tensor2, // [ne0,1,ne1*ne2,ne3]
|
||||
1, 0, 2, 3));
|
||||
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_mul_mat(ctx,
|
||||
ggml_sub(ctx,
|
||||
ggml_diag(ctx, tensor),
|
||||
ggml_mul_mat(ctx, tensor_t, tensor_t)),
|
||||
tensor->grad),
|
||||
src0->grad, // [ne0,ne1,ne2,ne3]
|
||||
ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
|
||||
ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
|
||||
ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||
ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||
tensor2), // [ne0,1,ne1*ne2,ne3]
|
||||
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||
tensor2_t, // [1,ne0,ne1*ne2,ne3]
|
||||
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
|
||||
grad2), // [ne0,1,ne1*ne2,ne3]
|
||||
src0->grad),
|
||||
inplace);
|
||||
}
|
||||
} break;
|
||||
|
|
|
@ -859,7 +859,6 @@ int main(int argc, const char ** argv) {
|
|||
|
||||
int64_t ne2[4];
|
||||
get_random_dims(ne2, 4);
|
||||
ne2[1] = 1;
|
||||
|
||||
for (int ndims = 1; ndims <= 3; ++ndims) {
|
||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue