CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)

Johannes Gäßler 2024-06-01 15:47:04 +02:00 committed by GitHub
parent 9b596417af
commit 750f60c03e
7 changed files with 73 additions and 29 deletions
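For context, the dispatch change below stops forcing the vec FlashAttention kernels whenever the KV cache is quantized; per the commit title, quantized K/V data is instead dequantized to FP16 once the batch size exceeds 8 so the regular kernels can run. The following is a minimal, standalone sketch of that decision logic only — the names (KernelChoice, choose_fattn_kernel, kDeqBatchThreshold) are illustrative assumptions and not the actual ggml-cuda API.

// Illustrative sketch: models the kernel-dispatch idea described in the commit
// title, not the actual ggml-cuda implementation. All names are hypothetical.
#include <cstdio>

enum class KernelChoice { VecF16, VecF32, TileOrMMA };

struct DispatchInput {
    bool amd_gpu;        // stands in for cc >= CC_OFFSET_AMD
    bool fast_fp16;      // stands in for fast_fp16_available(cc)
    bool prec_default;   // stands in for precision == GGML_PREC_DEFAULT
    bool quantized_kv;   // K or V stored in a quantized type
    int  batch_size;     // stands in for Q->ne[1]
};

// Assumed batch threshold taken from the commit title ("batch > 8").
constexpr int kDeqBatchThreshold = 8;

static KernelChoice choose_fattn_kernel(const DispatchInput & in, bool & dequantize_kv_to_fp16) {
    dequantize_kv_to_fp16 = false;

    // On AMD the tile kernels perform poorly, so always pick a vec kernel.
    if (in.amd_gpu) {
        return (in.prec_default && in.fast_fp16) ? KernelChoice::VecF16 : KernelChoice::VecF32;
    }

    // Small batches: the vec kernels can consume quantized KV directly.
    if (in.batch_size <= kDeqBatchThreshold) {
        return (in.prec_default && in.fast_fp16) ? KernelChoice::VecF16 : KernelChoice::VecF32;
    }

    // Larger batches: dequantize quantized KV to FP16 so the non-vec kernels can run.
    if (in.quantized_kv) {
        dequantize_kv_to_fp16 = true;
    }
    return KernelChoice::TileOrMMA;
}

int main() {
    bool deq = false;
    const DispatchInput in{false, true, true, true, 32};
    const KernelChoice k = choose_fattn_kernel(in, deq);
    std::printf("kernel=%d dequantize_kv=%d\n", static_cast<int>(k), static_cast<int>(deq));
    return 0;
}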


@@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
-    const ggml_tensor * K   = dst->src[1];
-    const ggml_tensor * V   = dst->src[2];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];

-    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
-
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {