metal : add tests, fix scaling, support C > 32

This commit is contained in:
Georgi Gerganov 2024-01-28 15:42:57 +02:00
parent 77f6976a87
commit ecc466a460
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 47 additions and 37 deletions

View file

@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case {
}
double max_nmse_err() override {
return 5e-5;
return 5e-4;
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8));
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7));
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1));
#if !defined(__SANITIZE_THREAD__)
// FIXME: these tests use too much memory with thread sanitizer