successfully test permute backward
This commit is contained in:
parent
86b44a02e4
commit
a7a837047c
1 changed files with 61 additions and 0 deletions
|
@ -256,6 +256,8 @@ bool check_mat_mul(
|
|||
return true;
|
||||
}
|
||||
|
||||
#define NUM_PERMUTATIONS (4*3*2*1)
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = 128*1024*1024,
|
||||
|
@ -265,6 +267,32 @@ int main(int argc, const char ** argv) {
|
|||
|
||||
int64_t ne[4];
|
||||
|
||||
int all_permutations[4 * NUM_PERMUTATIONS];
|
||||
{
|
||||
int count = 0;
|
||||
for (int ax0=0; ax0<4; ++ax0) {
|
||||
for (int ax1=0; ax1<4; ++ax1) {
|
||||
if (ax1 == ax0) continue;
|
||||
for (int ax2=0; ax2<4; ++ax2) {
|
||||
if (ax2 == ax0) continue;
|
||||
if (ax2 == ax1) continue;
|
||||
for (int ax3=0; ax3<4; ++ax3) {
|
||||
if (ax3 == ax0) continue;
|
||||
if (ax3 == ax1) continue;
|
||||
if (ax3 == ax2) continue;
|
||||
assert(count < NUM_PERMUTATIONS);
|
||||
all_permutations[count*4+0] = ax0;
|
||||
all_permutations[count*4+1] = ax1;
|
||||
all_permutations[count*4+2] = ax2;
|
||||
all_permutations[count*4+3] = ax3;
|
||||
++count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// original loop: 1000
|
||||
int niter = 1000;
|
||||
const char *env = getenv("GGML_NLOOP");
|
||||
|
@ -565,6 +593,39 @@ int main(int argc, const char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
// permute
|
||||
{
|
||||
int64_t ne2[4];
|
||||
|
||||
const int nargs = 1;
|
||||
for (int ndims = 1; ndims <= 4; ++ndims)
|
||||
{
|
||||
// ggml_permute will set axes of dimensions below n_dims to 1.
|
||||
// to make ggml_permute correctly work on all axes,
|
||||
// the input tensor needs maximal n_dim of 4.
|
||||
for (int i=0; i<ndims; ++i) {
|
||||
ne2[i] = ne[i];
|
||||
}
|
||||
for (int i=ndims; i<4; ++i) {
|
||||
ne2[i] = 1;
|
||||
}
|
||||
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
|
||||
// sum requires contiguous tensor rows, so we only test the permutations where ax0 == 0 --> NUM_PERMUTATIONS/4.
|
||||
// when the logic for gradients work for these permutations, they should also work for the others.
|
||||
const int p = irand(NUM_PERMUTATIONS/4);
|
||||
const int ax0 = all_permutations[p*4+0];
|
||||
const int ax1 = all_permutations[p*4+1];
|
||||
const int ax2 = all_permutations[p*4+2];
|
||||
const int ax3 = all_permutations[p*4+3];
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3));
|
||||
|
||||
check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
// softmax
|
||||
{
|
||||
const int nargs = 1;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue