tighten abs error bounds for flash_attn in test-grad0
This commit is contained in:
parent
dbbc263313
commit
47055c929f
1 changed files with 2 additions and 2 deletions
|
@ -1493,7 +1493,7 @@ int main(int argc, const char ** argv) {
|
|||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
||||
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1534,7 +1534,7 @@ int main(int argc, const char ** argv) {
|
|||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
||||
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue