diff --git a/ggml.c b/ggml.c
index cd4f54bf1..9b8dcc196 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10721,7 +10721,7 @@ static void ggml_compute_forward_diag_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
+    GGML_ASSERT(params->ith == 0);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10737,11 +10737,11 @@ static void ggml_compute_forward_diag_f32(
     const int ne1 = dst->ne[1];
     const int ne2 = dst->ne[2];
     const int ne3 = dst->ne[3];
-    assert(ne00 == ne0);
-    assert(ne00 == ne1);
-    assert(ne01 == 1);
-    assert(ne02 == ne2);
-    assert(ne03 == ne3);
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne00 == ne1);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -10752,8 +10752,8 @@ static void ggml_compute_forward_diag_f32(
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
-    assert(nb00 == sizeof(float));
-    assert(nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
 
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = 0; i2 < ne2; i2++) {
@@ -13545,23 +13545,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     // dx = J * dy
                     // dxk = sum(Jkj * dyk)
 
-                    struct ggml_tensor * tensor_t = ggml_cont(ctx,
-                        ggml_permute(ctx,
-                            ggml_reshape(ctx,
-                                tensor,
-                                ggml_new_tensor(ctx,
-                                    tensor->type,
-                                    4, tensor->ne)),
+                    int64_t ne2[4] = {
+                        tensor->ne[0],
+                        1,
+                        tensor->ne[1]*tensor->ne[2],
+                        tensor->ne[3]
+                    };
+                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
+                        ggml_reshape_4d(ctx,
+                            ggml_cont(ctx, tensor),
+                            ne2[0], ne2[1], ne2[2], ne2[3]));
+
+                    struct ggml_tensor * grad2 = ggml_cont(ctx,
+                        ggml_reshape_4d(ctx,
+                            ggml_cont(ctx, tensor->grad),
+                            ne2[0], ne2[1], ne2[2], ne2[3]));
+
+                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
+                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
+                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
                             1, 0, 2, 3));
 
                     src0->grad =
                         ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_mul_mat(ctx,
-                                ggml_sub(ctx,
-                                    ggml_diag(ctx, tensor),
-                                    ggml_mul_mat(ctx, tensor_t, tensor_t)),
-                                tensor->grad),
+                            src0->grad,                   // [ne0,ne1,ne2,ne3]
+                            ggml_reshape(ctx,             // [ne0,ne1,ne2,ne3]
+                                ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
+                                    ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
+                                        ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
+                                            tensor2),     // [ne0,1,ne1*ne2,ne3]
+                                        ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
+                                            tensor2_t,    // [1,ne0,ne1*ne2,ne3]
+                                            tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
+                                    grad2),               // [ne0,1,ne1*ne2,ne3]
+                                src0->grad),
                             inplace);
                 }
             } break;
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
index aa8d7a97f..b4ae4e788 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.c
@@ -859,7 +859,6 @@ int main(int argc, const char ** argv) {
 
             int64_t ne2[4];
             get_random_dims(ne2, 4);
-            ne2[1] = 1;
 
             for (int ndims = 1; ndims <= 3; ++ndims) {
                 x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
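
For reference, the graph built in the soft_max backward hunk computes dx = (diag(y) - y*y^T) * dy per row of y = softmax(x), which reduces to dx[i] = y[i] * (dy[i] - dot(y, dy)). The standalone sketch below is not part of ggml and the function name softmax_backward_row is illustrative only; it shows the same per-row math in plain C and can be used to sanity-check the graph construction above.

// Reference sketch (not ggml code): backward pass of softmax for one row.
// Given y = softmax(x) and upstream gradient dy, computes
//   dx = (diag(y) - y*y^T) * dy  ==  dx[i] = y[i] * (dy[i] - dot(y, dy)).
#include <stddef.h>

static void softmax_backward_row(const float * y, const float * dy, float * dx, size_t n) {
    float dot = 0.0f;
    for (size_t i = 0; i < n; i++) {
        dot += y[i] * dy[i];          // dot(y, dy)
    }
    for (size_t i = 0; i < n; i++) {
        dx[i] = y[i] * (dy[i] - dot); // row i of (diag(y) - y*y^T) * dy
    }
}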