metal : optimize FA kernels (#10171)

* ggml : add ggml_flash_attn_ext_get_prec

* metal : use F16 precision in FA kernels

ggml-ci

* metal : minor clean-up

* metal : compile-guard bf16 FA kernels

ggml-ci

* build : remove obsolete compile flag [no ci]

* metal : prevent int overflows [no ci] (see the sketch after this commit list)

* cuda : disable BF16 FA

ggml-ci

* metal : fix BF16 requirement for FA kernels

ggml-ci

* make : clean-up [no ci]
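
The "prevent int overflows" change follows a standard pattern: promote index arithmetic to 64 bits before multiplying, so large tensor shapes cannot wrap a 32-bit int. A minimal sketch of that pattern; the names fa_offset, row, ne1, ne2 are illustrative, not the commit's actual symbols:

    #include <stdint.h>

    // Cast one operand to int64_t before the multiply; with plain int32_t
    // operands the product would be computed in 32 bits and could silently
    // overflow for large tensors.
    static int64_t fa_offset(int32_t row, int32_t ne1, int32_t ne2) {
        return (int64_t) row * ne1 * ne2;
    }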
Author: Georgi Gerganov (2024-11-08 13:47:22 +02:00), committed by GitHub
parent d05b3127bd
commit 841f27abdb
8 changed files with 498 additions and 339 deletions


@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
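
For reference, a minimal sketch (not part of this commit) of how a backend could use the new getter to decide on accumulator precision; the helper name fa_needs_f32 is hypothetical:

    #include <stdbool.h>
    #include "ggml.h"

    // Hypothetical helper: report whether a graph node is a flash-attention
    // op whose precision was set to full F32 via
    // ggml_flash_attn_ext_set_prec().
    static bool fa_needs_f32(const struct ggml_tensor * node) {
        return node->op == GGML_OP_FLASH_ATTN_EXT &&
               ggml_flash_attn_ext_get_prec(node) == GGML_PREC_F32;
    }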