FP16 V still works
parent 1b49f47c22
commit ca6d82885c
1 changed file with 27 additions and 8 deletions
@@ -5,8 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64

-template<int D, int ncols, int nwarps, int parallel_blocks,
-        typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, // D == head size
+        typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+        typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
@@ -283,20 +284,24 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+template <int cols_per_block, int parallel_blocks,
+        typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+        typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -305,6 +310,20 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     }
 }

+template <int cols_per_block, int parallel_blocks,
+        typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * V = dst->src[2];
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 template <int cols_per_block, int parallel_blocks>
 void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -312,13 +331,13 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *

     switch (K->type) {
         case GGML_TYPE_Q4_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
             break;
         case GGML_TYPE_Q8_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
             break;
         case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
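
For context, the change threads the V tensor's type through the same launcher chain that already existed for K: a runtime switch on K->type picks the K template parameters, the new launch_fattn_tile_f16_V_type switches on V->type to pick the V template parameters, and only the innermost launch_fattn_tile_f16_64_128 instantiates the kernel with both fixed at compile time. Below is a minimal standalone C++ sketch of that two-level runtime-to-template dispatch pattern; all names in it (Type, dispatch_k, dispatch_v, launch_kernel) are hypothetical stand-ins, not the ggml API.

// Minimal standalone sketch (not part of the commit) of two-level
// runtime-to-template dispatch: resolve the K type first, then the V type,
// so only the innermost function instantiates the kernel for a (K, V) pair.
#include <cstdio>

enum class Type { F16, Q4_0, Q8_0 }; // stand-in for the runtime ggml_type tag

struct F16  { static constexpr const char * name = "f16";  };
struct Q4_0 { static constexpr const char * name = "q4_0"; };
struct Q8_0 { static constexpr const char * name = "q8_0"; };

// Innermost stage: K and V are both compile-time template parameters here,
// analogous to launch_fattn_tile_f16_64_128 instantiating the kernel.
template <typename TK, typename TV>
void launch_kernel() {
    std::printf("launch kernel<K=%s, V=%s>\n", TK::name, TV::name);
}

// Second stage: map the runtime V type onto a template parameter, analogous
// to launch_fattn_tile_f16_V_type. Only FP16 V is accepted, as in the commit.
template <typename TK>
void dispatch_v(Type v_type) {
    switch (v_type) {
        case Type::F16: launch_kernel<TK, F16>(); break;
        default:        std::printf("unsupported V type\n"); break; // GGML_ASSERT(false) in the real code
    }
}

// First stage: map the runtime K type onto a template parameter, analogous
// to launch_fattn_tile_f16_K_type.
void dispatch_k(Type k_type, Type v_type) {
    switch (k_type) {
        case Type::Q4_0: dispatch_v<Q4_0>(v_type); break;
        case Type::Q8_0: dispatch_v<Q8_0>(v_type); break;
        case Type::F16:  dispatch_v<F16>(v_type);  break;
    }
}

int main() {
    dispatch_k(Type::Q8_0, Type::F16); // prints: launch kernel<K=q8_0, V=f16>
}

The payoff of this structure is that supporting a new V type later is a one-line case in the middle dispatcher, while the combinatorial K-by-V instantiation stays confined to the innermost launcher.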