test flash attention backward pass
need to set loose error bounds to pass. the finitie differences are close to numeric limits and often return quite different values than the backward pass. reducing eps further lets the gradients vanish completely. likewise setting eps to big results in wronger values. the softmax in the middle of the function is probably the most responsible for the numeric issues using finite differences.
This commit is contained in:
parent
38560b6d51
commit
70c08318af
1 changed files with 40 additions and 1 deletions
|
@ -5,7 +5,7 @@
|
|||
#include <stdlib.h>
|
||||
#include <assert.h>
|
||||
|
||||
#define MAX_NARGS 2
|
||||
#define MAX_NARGS 3
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
|
@ -1143,6 +1143,45 @@ int main(int argc, const char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
// flash_attn
|
||||
{
|
||||
const int nargs = 3;
|
||||
|
||||
int64_t ne2[4];
|
||||
|
||||
get_random_dims(ne2, 4);
|
||||
int64_t D = ne2[0];
|
||||
int64_t N = ne2[1];
|
||||
int64_t M = ne2[2] + N;
|
||||
int64_t B = ne2[3];
|
||||
|
||||
for (int masked = 0; masked <= 1; ++masked) {
|
||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||
int64_t neq[4] = { D, N, B, ne[3] };
|
||||
int64_t nek[4] = { D, M, B, ne[3] };
|
||||
int64_t nev[4] = { M, D, B, ne[3] };
|
||||
if (ndims == 2) {
|
||||
neq[2] = 1; neq[3] = 1;
|
||||
nek[2] = 1; nek[3] = 1;
|
||||
nev[2] = 1; nev[3] = 1;
|
||||
} else if (ndims == 3) {
|
||||
neq[3] = 1;
|
||||
nek[3] = 1;
|
||||
nev[3] = 1;
|
||||
}
|
||||
x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||
x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||
x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
ggml_set_param(ctx0, x[1]);
|
||||
ggml_set_param(ctx0, x[2]);
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_free(ctx0);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue