From 1dd185751ea6bbdee418b4d90b63d9a948f15db7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Mon, 20 May 2024 00:23:43 +0200
Subject: [PATCH] q8_0 k works

---
 ggml-cuda/fattn-tile-f16.cu | 36 +++++++++++++++++++++++++++---------
 ggml-cuda/fattn.cu          |  2 +-
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 694ba81ba..1765ef183 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -50,11 +50,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
     const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;
 
     const int stride_K = nb11 / sizeof(type_k);
-    const int stride_KV2 = nb11 / sizeof(half2);
+    const int stride_KV2 = nb11*qkk / (2*sizeof(type_k));
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
@@ -111,7 +111,7 @@ static __global__ void flash_attn_tile_ext_f16(
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
                 half2 tmp;
-                dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp);
+                dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
                 KV_tmp[i_KQ][k_KQ] = tmp;
             }
         }
@@ -267,20 +267,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template
+template
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
             launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
             launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -289,6 +289,24 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context *
     }
 }
 
+
+template
+void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * K = dst->src[1];
+
+    switch (K->type) {
+        case GGML_TYPE_Q8_0:
+            launch_fattn_tile_f16_64_128(ctx, dst);
+            break;
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_64_128(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
@@ -299,18 +317,18 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128(ctx, dst);
+        launch_fattn_tile_f16_K_type(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128(ctx, dst);
+        launch_fattn_tile_f16_K_type(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128(ctx, dst);
+    launch_fattn_tile_f16_K_type(ctx, dst);
 }
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index e6159ec69..046bed129 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -464,7 +464,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
+    if (true || ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
         ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
         return;
     }
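
Note (not part of the patch): below is a minimal host-side sketch of the stride arithmetic the first hunk introduces, assuming a Q8_0 layout of qkk = 32 int8 quants plus one fp16 scale per block (34 bytes) and an f16 V cache with head size D = 128. block_q8_0_sketch and the variable names are illustrative stand-ins, not the ggml definitions.

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for a Q8_0 block: one fp16 scale + 32 int8 quants = 34 bytes.
struct block_q8_0_sketch {
    uint16_t d;       // fp16 scale, stored here as raw bits
    int8_t   qs[32];  // quantized values
};

int main() {
    const int    D           = 128;                        // head size
    const int    qkk         = 32;                         // quants per K block
    const size_t type_k_size = sizeof(block_q8_0_sketch);  // 34 bytes

    // nb11: byte stride of one quantized K row (4 blocks * 34 B = 136 B).
    const size_t nb11 = (size_t) (D / qkk) * type_k_size;

    // stride_K: K row stride counted in blocks, as indexed by dequantize_k.
    const size_t stride_K = nb11 / type_k_size;                             // 4

    // stride_KV2: V row stride in half2 elements, recovered from the K byte stride.
    const size_t stride_KV2 = nb11 * qkk / (2 * type_k_size);               // 64 == D/2

    // V pointer offset: a K byte offset rescaled to f16 bytes via sizeof(half)*qkk/sizeof(type_k).
    const size_t v_row_bytes = nb11 * sizeof(uint16_t) * qkk / type_k_size; // 256 == D*sizeof(half)

    printf("stride_K=%zu stride_KV2=%zu v_row_bytes=%zu\n", stride_K, stride_KV2, v_row_bytes);
    return 0;
}

If the F16 path instantiates something like type_k = half2 with qkk = 2 (an assumption, since the template argument lists are not visible above), the same formulas reduce to the removed lines: stride_KV2 = nb11 / sizeof(half2) and an unscaled V offset.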