CUDA: quantized KV support for FA vec

commit 672244a88b
parent 10b1e45876

11 changed files with 826 additions and 142 deletions
@@ -45,6 +45,9 @@ static __global__ void flash_attn_ext_f16(
         const int nb11,
         const int nb12,
         const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
         const int ne0,
         const int ne1,
         const int ne2,
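
The three added parameters are the byte strides of V. Previously the kernel could reuse K's strides (nb11/nb12/nb13) for V because both tensors always had the same type; once K and V can be quantized independently, their rows can have different byte sizes, so V is passed its own strides nb21/nb22/nb23. Below is a minimal sketch of the resulting byte-based indexing with an illustrative q8_0-style block; the struct and function names are hypothetical, not the commit's code.

#include <cstdint>
#include <cuda_fp16.h>

// Illustrative q8_0-style block: 32 quantized values sharing one fp16 scale.
struct block_q8_sketch {
    half   d;       // per-block scale
    int8_t qs[32];  // quantized values
};

// Fetch one dequantized element of V. nb21/nb22/nb23 are V's own byte
// strides; with quantized KV they generally differ from K's nb11/nb12/nb13.
__device__ float get_V_sketch(const char * V,
        const int nb21, const int nb22, const int nb23,
        const int i, const int j, const int head, const int batch) {
    const char * row = V + batch*nb23 + head*nb22 + j*nb21;        // byte offsets
    const block_q8_sketch * b = (const block_q8_sketch *) row + i/32;
    return __half2float(b->d) * (float) b->qs[i % 32];             // dequantize on load
}
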
@@ -457,14 +460,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
+    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
+
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT) {
+    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
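
The dispatch change mirrors the kernel support: quantized K/V is only implemented in the vec kernels, which dequantize blocks as they are loaded, so quantized_KV routes through the same branch as AMD devices (where the tile kernels perform poorly). Within that branch, the fp16 vec kernel is now also gated on fast_fp16_available(cc), with the f32 variant as the fallback. The selection rule, restated as a standalone sketch (pick_fattn_vec_sketch is hypothetical, and GGML_PREC_DEFAULT is assumed to be 0 as in ggml's enum):

#include <cstdint>

enum class vec_kernel { f16_no_mma, f32 };

// Sketch of the selection inside the vec branch above: use the fp16 vec
// kernel only when default precision was requested and the device has fast
// fp16 arithmetic; otherwise fall back to the f32 vec kernel.
static vec_kernel pick_fattn_vec_sketch(int32_t precision, bool fast_fp16) {
    const int32_t prec_default = 0; // assumed value of GGML_PREC_DEFAULT
    return (precision == prec_default && fast_fp16)
        ? vec_kernel::f16_no_mma
        : vec_kernel::f32;
}

In llama.cpp terms this is the path taken when flash attention is enabled with a quantized KV cache, e.g. -fa combined with -ctk q8_0 -ctv q8_0 (assuming the flag names of this era).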