From 70c08318af062c31e47fd6a914e9f3abf8db385e Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 29 May 2023 23:51:40 +0200 Subject: [PATCH] test flash attention backward pass need to set loose error bounds to pass. the finitie differences are close to numeric limits and often return quite different values than the backward pass. reducing eps further lets the gradients vanish completely. likewise setting eps to big results in wronger values. the softmax in the middle of the function is probably the most responsible for the numeric issues using finite differences. --- tests/test-grad0.c | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/test-grad0.c b/tests/test-grad0.c index b7d68cad9..c8c2c0f71 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -5,7 +5,7 @@ #include #include -#define MAX_NARGS 2 +#define MAX_NARGS 3 #undef MIN #undef MAX @@ -1143,6 +1143,45 @@ int main(int argc, const char ** argv) { } } + // flash_attn + { + const int nargs = 3; + + int64_t ne2[4]; + + get_random_dims(ne2, 4); + int64_t D = ne2[0]; + int64_t N = ne2[1]; + int64_t M = ne2[2] + N; + int64_t B = ne2[3]; + + for (int masked = 0; masked <= 1; ++masked) { + for (int ndims = 2; ndims <= 4; ++ndims) { + int64_t neq[4] = { D, N, B, ne[3] }; + int64_t nek[4] = { D, M, B, ne[3] }; + int64_t nev[4] = { M, D, B, ne[3] }; + if (ndims == 2) { + neq[2] = 1; neq[3] = 1; + nek[2] = 1; nek[3] = 1; + nev[2] = 1; nev[3] = 1; + } else if (ndims == 3) { + neq[3] = 1; + nek[3] = 1; + nev[3] = 1; + } + x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f); + x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f); + x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f); + ggml_set_param(ctx0, x[0]); + ggml_set_param(ctx0, x[1]); + ggml_set_param(ctx0, x[2]); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); + + check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + } + } + } ggml_free(ctx0); }