CUDA: fix Volta FlashAttention logic (#11615)
This commit is contained in:
parent
d92cb67e37
commit
21c84b5d2d
2 changed files with 3 additions and 2 deletions
|
@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
|||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
// case 256:
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
// break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
return;
|
||||
}
|
||||
|
||||
if (!new_mma_available(cc)) {
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
|
@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
if (cc == GGML_CUDA_CC_VOLTA) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue