From 75096c6e6ee51baffaa4de7c7f2fc889065fab9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:29:29 +0200 Subject: [PATCH] q8_0 works --- ggml-cuda/fattn-common.cuh | 1 - ggml-cuda/fattn-tile-f16.cu | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 4adbcc6f4..e08119c14 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 2e3baa793..734d676c2 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -56,8 +56,8 @@ static __global__ void flash_attn_tile_ext_f16( const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_K = nb11 / sizeof(type_k); - const int stride_V = nb11*qkk / (sizeof(type_v)*qkv); + const int stride_K = nb11/sizeof(type_k); + const int stride_V = stride_K*qkk/qkv; const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -318,8 +318,13 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * V = dst->src[2]; switch (V->type) { + case GGML_TYPE_Q8_0: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst); + break; case GGML_TYPE_F16: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst); break; default: GGML_ASSERT(false);