From 47055c929fa4696a87c0ea10fc818d86359e622f Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 3 Jul 2023 18:45:54 +0200 Subject: [PATCH] tighten abs error bounds for flash_attn in test-grad0 --- tests/test-grad0.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test-grad0.c b/tests/test-grad0.c index 0bbeff270..aba4b9c20 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -1493,7 +1493,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); } } } @@ -1534,7 +1534,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); } } }