metal : optimize FA kernels (#10171)
* ggml : add ggml_flash_attn_ext_get_prec * metal : use F16 precision in FA kernels ggml-ci * metal : minor clean-up * metal : compile-guard bf16 FA kernels ggml-ci * build : remove obsolete compile flag [no ci] * metal : prevent int overflows [no ci] * cuda : disable BF16 FA ggml-ci * metal : fix BF16 requirement for FA kernels ggml-ci * make : clean-up [no ci]
This commit is contained in:
parent
d05b3127bd
commit
841f27abdb
8 changed files with 498 additions and 339 deletions
|
@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (precision != GGML_PREC_DEFAULT) {
|
||||
if (prec != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
|
@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
|
@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
}
|
||||
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
} else if(Q->ne[0] <= 128) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue