q4_0 works

This commit is contained in:
Johannes Gäßler 2024-05-20 10:47:38 +02:00
parent 1dd185751e
commit 1b49f47c22

View file

@ -5,7 +5,8 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks,
typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -48,7 +49,8 @@ static __global__ void flash_attn_tile_ext_f16(
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const float2 * Q_f2 = (const float2 *) Q_f;
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
@ -81,12 +83,26 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (qrk == 1) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
const int i = i0 + 2*threadIdx.x;
const int iqs = (i%qkk)/qrk;
const int iybs = i - i%qkk;
float2 tmp;
tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
}
@ -111,7 +127,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int k_KQ = k_KQ_0 + threadIdx.x;
half2 tmp;
dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
KV_tmp[i_KQ][k_KQ] = tmp;
}
}
@ -267,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k>
template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
@ -295,11 +311,14 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * K = dst->src[1];
switch (K->type) {
case GGML_TYPE_Q4_0:
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
break;
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, dequantize_q8_0>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
break;
case GGML_TYPE_F16:
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, convert_f16>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
break;
default:
GGML_ASSERT(false);