From 1b49f47c2257038f8978c7bb7c85ea706619d15a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Mon, 20 May 2024 10:47:38 +0200
Subject: [PATCH] q4_0 works

---
 ggml-cuda/fattn-tile-f16.cu | 43 ++++++++++++++++++++++++++-----------
 1 file changed, 31 insertions(+), 12 deletions(-)

diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 1765ef183..95e7022ee 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -5,7 +5,8 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, dequantize_kernel_t dequantize_k, int qkk> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks,
+         typename type_k, dequantize_kernel_t dequantize_k, int qkk, int qrk> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -48,7 +49,8 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+    const float  * Q_f   = (const float  *) (Q + nb02* blockIdx.y + nb01*ic0);
+    const float2 * Q_f2  = (const float2 *) Q_f;
     const type_k * K_h   = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
     const half   * maskh = (const half   *) mask + ne11*ic0;
@@ -81,12 +83,26 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
+        if (qrk == 1) {
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
-            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+                const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+                Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        } else {
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+                const int i = i0 + 2*threadIdx.x;
+                const int iqs  = (i%qkk)/qrk;
+                const int iybs = i - i%qkk;
+
+                float2 tmp;
+                tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
+                tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
+                Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
         }
     }
 
@@ -111,7 +127,7 @@ static __global__ void flash_attn_tile_ext_f16(
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
             half2 tmp;
-            dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
+            dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
             KV_tmp[i_KQ][k_KQ] = tmp;
         }
     }
@@ -267,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, typename type_k, dequantize_kernel_t dequantize_k, int qkk>
+template <int cols_per_block, int parallel_blocks, typename type_k, dequantize_kernel_t dequantize_k, int qkk, int qrk>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D             = 64;
             constexpr int nwarps        = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, dequantize_k, qkk>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, dequantize_k, qkk, qrk>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D             = 128;
             constexpr int nwarps        = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, dequantize_k, qkk>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, dequantize_k, qkk, qrk>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -295,11 +311,14 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const ggml_tensor * K = dst->src[1];
     switch (K->type) {
+        case GGML_TYPE_Q4_0:
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, dequantize_q4_0, QK4_0, QR4_0>(ctx, dst);
+            break;
         case GGML_TYPE_Q8_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, dequantize_q8_0, QK8_0>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, dequantize_q8_0, QK8_0, QR8_0>(ctx, dst);
             break;
         case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half, convert_f16, 1>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half, convert_f16, 1, 1>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
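
Aside (not part of the patch above): the interleaved gather in the new qrk != 1 path for Q mirrors the element order that q4_0-style dequantization produces for K, so each half2 of Q pairs up with the half2 that dequantize_k writes into the K tile. The following standalone sketch, assuming qkk = QK4_0 = 32 and qrk = QR4_0 = 2, prints which two logical elements byte iqs of a block corresponds to, i.e. why tmp.x and tmp.y are read at iybs + iqs + 0*qkk/2 and iybs + iqs + 1*qkk/2:

    // Standalone sketch (not from the patch): q4_0-style blocks pack qrk = 2
    // values per byte; dequantizing byte iqs of a block yields the logical
    // elements iybs + iqs (low nibble) and iybs + iqs + qkk/2 (high nibble).
    #include <cstdio>

    int main() {
        const int qkk = 32; // elements per block (QK4_0)
        const int qrk = 2;  // values packed per byte (QR4_0)

        for (int i = 0; i < 2*qkk; i += 2) {  // i = i0 + 2*threadIdx.x in the kernel
            const int iqs  = (i % qkk) / qrk; // byte index within the block
            const int iybs = i - i % qkk;     // first logical element of the block

            // The two f32 Q values gathered into one half2 by the patch,
            // matching the pair dequantize_k produces for the K tile.
            printf("i=%2d -> block=%d byte=%2d elements=(%2d, %2d)\n",
                   i, i/qkk, iqs, iybs + iqs + 0*qkk/2, iybs + iqs + 1*qkk/2);
        }
        return 0;
    }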