CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)
This commit is contained in:
parent
9b596417af
commit
750f60c03e
7 changed files with 73 additions and 29 deletions
|
@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
|||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
|
||||
const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD || quantized_KV) {
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue