CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)

Johannes Gäßler 2024-06-01 15:47:04 +02:00 committed by GitHub
parent 9b596417af
commit 750f60c03e
7 changed files with 73 additions and 29 deletions


@@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }

 template <int D, ggml_type type_K, ggml_type type_V>
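
For context on the two new flags: need_f16_K and need_f16_V tell launch_fattn whether this kernel instantiation can only read K/V in FP16. Per the conditions above, the vector kernel reads quantized K directly only for head size 128, and quantized V only for head sizes 128 and 64; for every other head size the launcher must first dequantize the KV cache to FP16. The sketch below illustrates that launch-side pattern under stated assumptions; it is not ggml-cuda's actual launch_fattn, and the block_q8 layout and all function names are hypothetical.

// Illustrative CUDA sketch of "dequantize KV to FP16 before launch".
// block_q8, dequantize_q8_to_f16, and maybe_dequantize are assumed names,
// not ggml-cuda APIs.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

struct block_q8 {        // toy quantization block: one FP16 scale per 32 int8 values
    half   scale;
    int8_t qs[32];
};

__global__ void dequantize_q8_to_f16(const block_q8 * x, half * y) {
    const int ib = blockIdx.x;   // one quantization block per CUDA block
    const int i  = threadIdx.x;  // 32 threads, one value each
    y[32*ib + i] = __float2half(__half2float(x[ib].scale) * (float) x[ib].qs[i]);
}

// Returns a pointer the attention kernel can consume: the quantized blocks as-is
// if the kernel dequantizes internally, otherwise an FP16 copy written into tmp.
static const void * maybe_dequantize(const block_q8 * quantized, half * tmp,
                                     int n_blocks, bool need_f16, cudaStream_t stream) {
    if (!need_f16) {
        return quantized; // kernel has a dedicated path for this quantized type
    }
    dequantize_q8_to_f16<<<n_blocks, 32, 0, stream>>>(quantized, tmp);
    return tmp;
}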