metal : use F16 precision in FA kernels

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-06 15:33:30 +02:00
parent 25e877309a
commit 7facc29d69
No known key found for this signature in database
GPG key ID: BF970631944C16B7
7 changed files with 476 additions and 333 deletions

View file

@ -3745,7 +3745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
}
}