CUDA: faster large batch FA without tensor cores (#7314)
This commit is contained in:
parent
82ca83db3c
commit
0fc1e820a9
7 changed files with 823 additions and 15 deletions
|
@ -1,5 +1,7 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile-f16.cuh"
|
||||
#include "fattn-tile-f32.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
@ -88,7 +90,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = blockIdx.y;
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
@ -541,13 +543,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue