cuda : fix flash_attn kernel to produce same results as CPU

This commit is contained in:
Georgi Gerganov 2024-02-01 09:40:56 +02:00
parent fd878f71ed
commit 71b69aa7fd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 42 additions and 26 deletions

View file

@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int hs : { 128, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, 512 }) {
for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
}