From ca6d82885cf8bcb87053519b6f8dfd6ea30aec47 Mon Sep 17 00:00:00 2001
From: Johannes Gäßler
Date: Mon, 20 May 2024 10:54:42 +0200
Subject: [PATCH] FP16 V still works

---
 ggml-cuda/fattn-tile-f16.cu | 35 +++++++++++++++++++++++++++--------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 95e7022ee..0bcfda7b4 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -5,8 +5,9 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int cols_per_block, int nwarps, int parallel_blocks, ggml_type type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
+template<int D, int cols_per_block, int nwarps, int parallel_blocks, ggml_type type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+         ggml_type type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -283,20 +284,24 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, ggml_type type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+template <int cols_per_block, int parallel_blocks,
+          ggml_type type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+          ggml_type type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
            constexpr int D      = 64;
            constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
            launch_fattn<D, cols_per_block, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        } break;
        case 128: {
            constexpr int D      = 128;
            constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
            launch_fattn<D, cols_per_block, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        } break;
        default: {
@@ -305,6 +310,20 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
    }
 }
 
+template <int cols_per_block, int parallel_blocks,
+          ggml_type type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * V = dst->src[2];
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, GGML_TYPE_F16, 1, 1, convert_f16>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
 
 template <int cols_per_block, int parallel_blocks>
 void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -312,13 +331,13 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
 
    switch (K->type) {
        case GGML_TYPE_Q4_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, GGML_TYPE_Q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, GGML_TYPE_Q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
            break;
        case GGML_TYPE_Q8_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, GGML_TYPE_Q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, GGML_TYPE_Q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
            break;
        case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, GGML_TYPE_F16, 1, 1, convert_f16>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, GGML_TYPE_F16, 1, 1, convert_f16>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false);