fix soft_max backward pass for input->ne[1] != 1
This commit is contained in:
parent
b4c273f7a3
commit
8cf04fec9d
2 changed files with 38 additions and 22 deletions
59
ggml.c
59
ggml.c
|
@ -10721,7 +10721,7 @@ static void ggml_compute_forward_diag_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
assert(params->ith == 0);
|
GGML_ASSERT(params->ith == 0);
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
return;
|
return;
|
||||||
|
@ -10737,11 +10737,11 @@ static void ggml_compute_forward_diag_f32(
|
||||||
const int ne1 = dst->ne[1];
|
const int ne1 = dst->ne[1];
|
||||||
const int ne2 = dst->ne[2];
|
const int ne2 = dst->ne[2];
|
||||||
const int ne3 = dst->ne[3];
|
const int ne3 = dst->ne[3];
|
||||||
assert(ne00 == ne0);
|
GGML_ASSERT(ne00 == ne0);
|
||||||
assert(ne00 == ne1);
|
GGML_ASSERT(ne00 == ne1);
|
||||||
assert(ne01 == 1);
|
GGML_ASSERT(ne01 == 1);
|
||||||
assert(ne02 == ne2);
|
GGML_ASSERT(ne02 == ne2);
|
||||||
assert(ne03 == ne3);
|
GGML_ASSERT(ne03 == ne3);
|
||||||
|
|
||||||
const int nb00 = src0->nb[0];
|
const int nb00 = src0->nb[0];
|
||||||
const int nb01 = src0->nb[1];
|
const int nb01 = src0->nb[1];
|
||||||
|
@ -10752,8 +10752,8 @@ static void ggml_compute_forward_diag_f32(
|
||||||
const int nb2 = dst->nb[2];
|
const int nb2 = dst->nb[2];
|
||||||
const int nb3 = dst->nb[3];
|
const int nb3 = dst->nb[3];
|
||||||
|
|
||||||
assert(nb00 == sizeof(float));
|
GGML_ASSERT(nb00 == sizeof(float));
|
||||||
assert(nb0 == sizeof(float));
|
GGML_ASSERT(nb0 == sizeof(float));
|
||||||
|
|
||||||
for (int i3 = 0; i3 < ne3; i3++) {
|
for (int i3 = 0; i3 < ne3; i3++) {
|
||||||
for (int i2 = 0; i2 < ne2; i2++) {
|
for (int i2 = 0; i2 < ne2; i2++) {
|
||||||
|
@ -13545,23 +13545,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
// dx = J * dy
|
// dx = J * dy
|
||||||
// dxk = sum_j(Jkj * dyj)
|
// dxk = sum_j(Jkj * dyj)
|
||||||
|
|
||||||
struct ggml_tensor * tensor_t = ggml_cont(ctx,
|
int64_t ne2[4] = {
|
||||||
ggml_permute(ctx,
|
tensor->ne[0],
|
||||||
ggml_reshape(ctx,
|
1,
|
||||||
tensor,
|
tensor->ne[1]*tensor->ne[2],
|
||||||
ggml_new_tensor(ctx,
|
tensor->ne[3]
|
||||||
tensor->type,
|
};
|
||||||
4, tensor->ne)),
|
struct ggml_tensor * tensor2 = ggml_cont(ctx,
|
||||||
|
ggml_reshape_4d(ctx,
|
||||||
|
ggml_cont(ctx, tensor),
|
||||||
|
ne2[0], ne2[1], ne2[2], ne2[3]));
|
||||||
|
|
||||||
|
struct ggml_tensor * grad2 = ggml_cont(ctx,
|
||||||
|
ggml_reshape_4d(ctx,
|
||||||
|
ggml_cont(ctx, tensor->grad),
|
||||||
|
ne2[0], ne2[1], ne2[2], ne2[3]));
|
||||||
|
|
||||||
|
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
|
||||||
|
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
|
||||||
|
tensor2, // [ne0,1,ne1*ne2,ne3]
|
||||||
1, 0, 2, 3));
|
1, 0, 2, 3));
|
||||||
|
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_impl(ctx,
|
ggml_add_impl(ctx,
|
||||||
src0->grad,
|
src0->grad, // [ne0,ne1,ne2,ne3]
|
||||||
ggml_mul_mat(ctx,
|
ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
|
||||||
ggml_sub(ctx,
|
ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
|
||||||
ggml_diag(ctx, tensor),
|
ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||||
ggml_mul_mat(ctx, tensor_t, tensor_t)),
|
ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||||
tensor->grad),
|
tensor2), // [ne0,1,ne1*ne2,ne3]
|
||||||
|
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
||||||
|
tensor2_t, // [1,ne0,ne1*ne2,ne3]
|
||||||
|
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
|
||||||
|
grad2), // [ne0,1,ne1*ne2,ne3]
|
||||||
|
src0->grad),
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
|
@ -859,7 +859,6 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
int64_t ne2[4];
|
int64_t ne2[4];
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
ne2[1] = 1;
|
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 3; ++ndims) {
|
for (int ndims = 1; ndims <= 3; ++ndims) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue