q4_0 works
This commit is contained in:
parent
1dd185751e
commit
1b49f47c22
1 changed files with 31 additions and 12 deletions
|
@ -5,7 +5,8 @@
|
||||||
|
|
||||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||||
|
|
||||||
template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k> // D == head size
|
template<int D, int ncols, int nwarps, int parallel_blocks,
|
||||||
|
typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k> // D == head size
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -48,7 +49,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
|
||||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||||
|
const float2 * Q_f2 = (const float2 *) Q_f;
|
||||||
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
|
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
|
||||||
const half * maskh = (const half *) mask + ne11*ic0;
|
const half * maskh = (const half *) mask + ne11*ic0;
|
||||||
|
@ -81,12 +83,26 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
const int j = j0 + threadIdx.y;
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (qrk == 1) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
||||||
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||||
|
const int i = i0 + 2*threadIdx.x;
|
||||||
|
const int iqs = (i%qkk)/qrk;
|
||||||
|
const int iybs = i - i%qkk;
|
||||||
|
|
||||||
|
float2 tmp;
|
||||||
|
tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
|
||||||
|
tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
|
||||||
|
Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +127,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
half2 tmp;
|
half2 tmp;
|
||||||
dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
|
dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
|
||||||
KV_tmp[i_KQ][k_KQ] = tmp;
|
KV_tmp[i_KQ][k_KQ] = tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -267,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k>
|
template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
|
||||||
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
case 64: {
|
case 64: {
|
||||||
constexpr int D = 64;
|
constexpr int D = 64;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
constexpr int D = 128;
|
constexpr int D = 128;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
|
@ -295,11 +311,14 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
|
||||||
switch (K->type) {
|
switch (K->type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, dequantize_q8_0>(ctx, dst);
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, convert_f16>(ctx, dst);
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue