From 14e80c413bb411143f82f47f34b7e86c1575a7a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:46:02 +0200 Subject: [PATCH] q4_0 works --- ggml-cuda/fattn-tile-f16.cu | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 734d676c2..c6925c2a6 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -273,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16( dst_val /= __half2half2(kqsum_j); } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val); + + if (qrv == 1) { + dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val); + dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val); + } else { + const int iqs = (i0%qkv)/qrv; + const int iybs = i0 - i0%qkv; + dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val); + dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val); + } } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -318,6 +326,10 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * V = dst->src[2]; switch (V->type) { + case GGML_TYPE_Q4_0: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst); + break; case GGML_TYPE_Q8_0: launch_fattn_tile_f16_64_128< cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);