q4_0 works
commit 1b49f47c22
parent 1dd185751e

1 changed file with 31 additions and 12 deletions
@@ -5,7 +5,8 @@

 #define FATTN_KQ_STRIDE_TILE_F16 64

-template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks,
+         typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
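Note: the new qrk template parameter is the quantization ratio of the K type, i.e. how many K values are packed into one quantized item (QR4_0 == 2 for Q4_0, QR8_0 == 1 for Q8_0, and 1 for F16). A minimal sketch of how the extended template is instantiated; D == 64 and nwarps == 8 mirror the launcher further down, while ncols and parallel_blocks are illustrative placeholders, not values taken from this commit:

// Hedged sketch only: ncols (32) and parallel_blocks (4) are placeholders;
// the Q4_0 arguments mirror the launcher changes later in this diff.
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
    /*D=*/64, /*ncols=*/32, /*nwarps=*/8, /*parallel_blocks=*/4,
    /*type_k=*/block_q4_0, /*qkk=*/QK4_0, /*qrk=*/QR4_0, dequantize_q4_0>;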
@@ -48,7 +49,8 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+    const float  * Q_f   = (const float  *) (Q + nb02* blockIdx.y + nb01*ic0);
+    const float2 * Q_f2  = (const float2 *) Q_f;
     const type_k * K_h   = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
     const half   * maskh = (const half   *) mask + ne11*ic0;
@@ -81,12 +83,26 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;

+        if (qrk == 1) {
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;

-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
-            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+                const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+                Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        } else {
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+                const int i = i0 + 2*threadIdx.x;
+                const int iqs  = (i%qkk)/qrk;
+                const int iybs = i - i%qkk;
+
+                float2 tmp;
+                tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
+                tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
+                Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
         }
     }

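Note: the new else branch presumably loads Q with the same block-and-pair permutation in which dequantize_k emits K values when qrk > 1 (a single call yields the pair at positions iqs and iqs + qkk/2 of a block), so each half2 product in the KQ dot product still pairs matching elements. A self-contained sketch of that index mapping, using a hypothetical helper name for illustration:

// Hedged sketch of the index mapping above (hypothetical helper, not part of the commit).
// For an even element index i with block size qkk and quantization ratio qrk, the two
// Q values packed into one half2 come from these positions of the original row:
static inline void q_pair_positions(int i, int qkk, int qrk, int * pos_x, int * pos_y) {
    const int iqs  = (i % qkk) / qrk; // quantized item index inside the block
    const int iybs = i - i % qkk;     // first element of the block containing i
    *pos_x = iybs + iqs + 0*qkk/2;    // element stored in tmp.x
    *pos_y = iybs + iqs + 1*qkk/2;    // element stored in tmp.y
}
// Example with qkk = 32, qrk = 2 (Q4_0-like): i = 34 gives iqs = 1, iybs = 32,
// so the pair is built from row elements 33 and 49.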
@@ -111,7 +127,7 @@ static __global__ void flash_attn_tile_ext_f16(
                 const int k_KQ = k_KQ_0 + threadIdx.x;

                 half2 tmp;
-                dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
+                dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
                 KV_tmp[i_KQ][k_KQ] = tmp;
             }
         }
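Note: the third argument of dequantize_k is the index of the quantized item inside a block, so it is now divided by qrk to account for types where one item encodes more than one value. A short worked example of the arithmetic, assuming qkk = 32 and qrk = 2 (Q4_0-like; the numbers are illustrative only):

// Hedged arithmetic sketch for the indices passed to dequantize_k above.
static void k_index_example(void) {
    const int k_KQ = 18;              // example column handled by one thread
    const int elem = 2*k_KQ;          // = 36, first element of the half2 pair
    const int ib   = elem / 32;       // = 1, block offset added to the K_h pointer
    const int iqs  = (elem % 32) / 2; // = 2, quantized item index inside that block
    (void) ib; (void) iqs;
    // With qkk = 2 and qrk = 1 (the F16 path) the same formulas give ib = k_KQ and
    // iqs = 0, i.e. the call degenerates to a plain half2 load via convert_f16.
}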
@@ -267,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k>
+template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -295,11 +311,14 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * K = dst->src[1];

     switch (K->type) {
+        case GGML_TYPE_Q4_0:
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            break;
         case GGML_TYPE_Q8_0:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, dequantize_q8_0>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
             break;
         case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, convert_f16>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
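Note: QK4_0, QR4_0, QK8_0, and QR8_0 are the existing ggml block-size and quantization-ratio constants (32 and 2 for Q4_0, 32 and 1 for Q8_0). For orientation, a sketch of the Q4_0 block layout as commonly defined in the ggml CUDA code; treat it as an approximation from memory rather than the authoritative header:

// Hedged sketch of the Q4_0 block this commit dequantizes from the K cache.
#define QK4_0 32              // values per block
#define QR4_0  2              // values packed per quantized byte
typedef struct {
    half    d;                // per-block scale factor
    uint8_t qs[QK4_0 / 2];    // 4-bit quants, two values per byte
} block_q4_0;                 // dequantized value: x = (q - 8) * d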