successfully test permute backward

This commit is contained in:
xaedes 2023-04-28 17:47:23 +02:00
parent 86b44a02e4
commit a7a837047c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -256,6 +256,8 @@ bool check_mat_mul(
return true;
}
#define NUM_PERMUTATIONS (4*3*2*1)
int main(int argc, const char ** argv) {
struct ggml_init_params params = {
.mem_size = 128*1024*1024,
@ -265,6 +267,32 @@ int main(int argc, const char ** argv) {
int64_t ne[4];
int all_permutations[4 * NUM_PERMUTATIONS];
{
int count = 0;
for (int ax0=0; ax0<4; ++ax0) {
for (int ax1=0; ax1<4; ++ax1) {
if (ax1 == ax0) continue;
for (int ax2=0; ax2<4; ++ax2) {
if (ax2 == ax0) continue;
if (ax2 == ax1) continue;
for (int ax3=0; ax3<4; ++ax3) {
if (ax3 == ax0) continue;
if (ax3 == ax1) continue;
if (ax3 == ax2) continue;
assert(count < NUM_PERMUTATIONS);
all_permutations[count*4+0] = ax0;
all_permutations[count*4+1] = ax1;
all_permutations[count*4+2] = ax2;
all_permutations[count*4+3] = ax3;
++count;
}
}
}
}
}
// original loop: 1000
int niter = 1000;
const char *env = getenv("GGML_NLOOP");
@ -565,6 +593,39 @@ int main(int argc, const char ** argv) {
}
}
// permute
{
int64_t ne2[4];
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims)
{
// ggml_permute will set axes of dimensions below n_dims to 1.
// to make ggml_permute correctly work on all axes,
// the input tensor needs maximal n_dim of 4.
for (int i=0; i<ndims; ++i) {
ne2[i] = ne[i];
}
for (int i=ndims; i<4; ++i) {
ne2[i] = 1;
}
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
// sum requires contiguous tensor rows, so we only test the permutations where ax0 == 0 --> NUM_PERMUTATIONS/4.
// when the logic for gradients work for these permutations, they should also work for the others.
const int p = irand(NUM_PERMUTATIONS/4);
const int ax0 = all_permutations[p*4+0];
const int ax1 = all_permutations[p*4+1];
const int ax2 = all_permutations[p*4+2];
const int ax3 = all_permutations[p*4+3];
struct ggml_tensor * f = ggml_sum(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3));
check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
}
}
// softmax
{
const int nargs = 1;