FP16 V still works

Johannes Gäßler 2024-05-20 10:54:42 +02:00
parent 1b49f47c22
commit ca6d82885c

@@ -5,8 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks,
-         typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, // D == head size
+         typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+         typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -283,20 +284,24 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+template <int cols_per_block, int parallel_blocks,
+          typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+          typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -305,6 +310,20 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     }
 }
 
+template <int cols_per_block, int parallel_blocks,
+          typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * V = dst->src[2];
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 template <int cols_per_block, int parallel_blocks>
 void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
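For a concrete picture of what the two dispatch layers resolve to, here is a sketch (not code from the commit): with a Q4_0 K cache (dispatched in the next hunk), an FP16 V cache and head size 128, launch_fattn_tile_f16_K_type forwards the K block parameters, launch_fattn_tile_f16_V_type appends the FP16 V parameters, and launch_fattn_tile_f16_64_128 picks D and nwarps, so the launched kernel is roughly:

    // Hypothetical resolved instantiation for K = Q4_0, V = F16, D = 128, nwarps = 8;
    // cols_per_block and parallel_blocks are whatever the outer launcher chose.
    fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
        128, cols_per_block, 8, parallel_blocks,
        block_q4_0, QK4_0, QR4_0, dequantize_q4_0, // K: dequantized q4_0 blocks
        half2,      2,     1,     convert_f16>;    // V: FP16, the only V type wired up so far
    launch_fattn<128, parallel_blocks>(ctx, dst, fattn_kernel, 8, cols_per_block);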
@@ -312,13 +331,13 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
     switch (K->type) {
         case GGML_TYPE_Q4_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
             break;
         case GGML_TYPE_Q8_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
             break;
         case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
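Only the GGML_TYPE_F16 case is handled in launch_fattn_tile_f16_V_type so far, which is what the commit title refers to: a quantized K cache combined with an FP16 V cache still works. Quantized V types would presumably be added later as further cases in that switch, mirroring the K-type dispatch above; a hypothetical sketch, not part of this commit:

        case GGML_TYPE_Q4_0: // hypothetical: forward q4_0 block parameters for V
            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks,
                type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
            break;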