test broadcasting mul_mat backward pass

This commit is contained in:
xaedes 2023-09-09 18:38:29 +02:00
parent aea8b6be74
commit dd3278619d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -749,12 +749,17 @@ int main(int argc, const char ** argv) {
{ {
const int nargs = 2; const int nargs = 2;
for (int ndims = 2; ndims <= 2; ++ndims) { for (int ndims = 2; ndims <= 4; ++ndims) {
int max_nrep = (ndims >= 3) ? 2 : 1;
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
{ {
int64_t ne2[4]; int64_t ne2[4];
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
ne2[0] = ne[0]; ne2[0] = ne[0];
ne2[2] = nrep2 * ne[2];
ne2[3] = nrep3 * ne[3];
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
} }
@ -767,9 +772,14 @@ int main(int argc, const char ** argv) {
GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims); GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
if (ndims == 2) {
// check_mat_mul does not support ndims > 2
check_mat_mul(m, x[1], x[0]); check_mat_mul(m, x[1], x[0]);
} }
} }
}
}
}
// elu, not yet fully implemented // elu, not yet fully implemented
if(0) if(0)