test broadcasting mul_mat backward pass
This commit is contained in:
parent
aea8b6be74
commit
dd3278619d
1 changed files with 27 additions and 17 deletions
|
@ -749,25 +749,35 @@ int main(int argc, const char ** argv) {
|
||||||
{
|
{
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
|
|
||||||
for (int ndims = 2; ndims <= 2; ++ndims) {
|
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
|
int max_nrep = (ndims >= 3) ? 2 : 1;
|
||||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
{
|
for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
|
||||||
int64_t ne2[4];
|
for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
|
||||||
get_random_dims(ne2, 4);
|
{
|
||||||
ne2[0] = ne[0];
|
int64_t ne2[4];
|
||||||
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
get_random_dims(ne2, 4);
|
||||||
|
ne2[0] = ne[0];
|
||||||
|
ne2[2] = nrep2 * ne[2];
|
||||||
|
ne2[3] = nrep3 * ne[3];
|
||||||
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
|
struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, m);
|
||||||
|
|
||||||
|
GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
|
||||||
|
|
||||||
|
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
if (ndims == 2) {
|
||||||
|
// check_mat_mul does not support ndims > 2
|
||||||
|
check_mat_mul(m, x[1], x[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
|
||||||
ggml_set_param(ctx0, x[1]);
|
|
||||||
|
|
||||||
struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, m);
|
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
|
|
||||||
|
|
||||||
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
|
||||||
check_mat_mul(m, x[1], x[0]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue