CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)
This commit is contained in:
parent 9b596417af
commit 750f60c03e
7 changed files with 73 additions and 29 deletions
@@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }

 template <int D, ggml_type type_K, ggml_type type_V>
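For context, the new need_f16_K / need_f16_V flags tell launch_fattn whether the selected kernel can read the K/V cache in its stored type directly, or whether it first needs an FP16 copy. The sketch below illustrates that pattern in a minimal, self-contained form; the names convert_f32_to_f16 and launch_with_optional_f16_kv are hypothetical and the source data is assumed to be FP32, whereas the actual launcher also handles ggml's quantized KV types.

// Minimal sketch of "convert K/V to FP16 only when the kernel needs it".
// Hypothetical names; assumes FP32 source data for simplicity.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void convert_f32_to_f16(const float * src, __half * dst, int64_t n) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = __float2half(src[i]);
    }
}

// Creates FP16 scratch copies of K and/or V only when requested, then the
// attention kernel would be launched with whichever pointers are valid.
void launch_with_optional_f16_kv(const float * K, const float * V, int64_t n_kv,
                                 bool need_f16_K, bool need_f16_V, cudaStream_t stream) {
    __half * K_f16 = nullptr;
    __half * V_f16 = nullptr;

    const int block = 256;
    const int grid  = (int) ((n_kv + block - 1)/block);

    if (need_f16_K) {
        cudaMallocAsync((void **) &K_f16, n_kv*sizeof(__half), stream);
        convert_f32_to_f16<<<grid, block, 0, stream>>>(K, K_f16, n_kv);
    }
    if (need_f16_V) {
        cudaMallocAsync((void **) &V_f16, n_kv*sizeof(__half), stream);
        convert_f32_to_f16<<<grid, block, 0, stream>>>(V, V_f16, n_kv);
    }

    // ... launch the attention kernel here, passing K_f16/V_f16 where they
    //     were created and the original pointers otherwise ...

    if (K_f16) { cudaFreeAsync(K_f16, stream); }
    if (V_f16) { cudaFreeAsync(V_f16, stream); }
}

Converting once per launch keeps the attention kernels themselves free of per-type dequantization logic; as the hunk above shows, whether a conversion is needed at all is decided at compile time from the head size D.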