q8_0 k works

Johannes Gäßler 2024-05-20 00:23:43 +02:00
parent 08d8a6b528
commit 1dd185751e
2 changed files with 28 additions and 10 deletions

View file

@@ -50,11 +50,11 @@ static __global__ void flash_attn_tile_ext_f16(
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
-const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_K = nb11 / sizeof(type_k);
-const int stride_KV2 = nb11 / sizeof(half2);
+const int stride_KV2 = nb11*qkk / (2*sizeof(type_k));
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
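The two updated lines in the hunk above exist because nb11 and nb12 are K's byte strides: once K is stored as q8_0 they no longer describe the (still f16) V tensor, so the V pointer offset and the half2 stride are rescaled by sizeof(half)*qkk/sizeof(type_k). A standalone sanity check of that arithmetic, as a sketch assuming ggml's QK8_0 == 32 and a 34-byte block_q8_0 (2-byte half scale plus 32 int8 quants):

```cpp
// Standalone check of the stride arithmetic above (host C++, no CUDA needed).
// Assumptions: QK8_0 == 32 and sizeof(block_q8_0) == 34; V stays f16 in this
// commit, only K is quantized.
#include <cstddef>

constexpr size_t D        = 128;          // head size ne10, as one example
constexpr size_t QKK_Q8_0 = 32;           // K elements packed per block_q8_0
constexpr size_t BS_Q8_0  = 2 + QKK_Q8_0; // sizeof(block_q8_0)
constexpr size_t BS_HALF  = 2;            // sizeof(half)
constexpr size_t BS_HALF2 = 4;            // sizeof(half2)

// nb11 = byte stride of one K row for the two supported K types:
constexpr size_t nb11_f16  = D * BS_HALF;
constexpr size_t nb11_q8_0 = (D / QKK_Q8_0) * BS_Q8_0;

// stride_KV2 = nb11*qkk / (2*sizeof(type_k)) must equal the number of half2
// values per (f16) V row, i.e. D/2, no matter how K is stored:
static_assert(nb11_f16  * 2        / (2 * BS_HALF2) == D / 2, "f16 K path");
static_assert(nb11_q8_0 * QKK_Q8_0 / (2 * BS_Q8_0)  == D / 2, "q8_0 K path");

// The V pointer uses the same factor: nb12*sizeof(half)*qkk/sizeof(type_k)
// converts K's byte stride into V's. For f16 K it is 2*2/4 == 1, which
// recovers the previous code path unchanged.
int main() { return 0; }
```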
@@ -111,7 +111,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int k_KQ = k_KQ_0 + threadIdx.x;
half2 tmp;
-dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp);
+dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp);
KV_tmp[i_KQ][k_KQ] = tmp;
}
}
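With K now typed as const type_k *, the kernel loads it through the dequantize_k callback instead of reading half2 directly: the pointer passed above already addresses the block that holds element 2*k_KQ, so the block index argument is 0 and (2*k_KQ)%qkk selects the pair of quants inside the block. A minimal sketch of what such a callback has to do for q8_0 (hypothetical name and simplified layout, not the exact llama.cpp dequantize_q8_0):

```cuda
#include <cuda_fp16.h>
#include <cstdint>

#define QK8_0 32
struct block_q8_0_sketch {   // mirrors ggml's block_q8_0: one half scale + 32 int8 quants
    half   d;
    int8_t qs[QK8_0];
};

// Contract assumed by the templated kernel: given a block pointer, a block
// index ib and an (even) intra-block index iqs, write two consecutive
// dequantized K values into a half2.
static __device__ __forceinline__ void dequantize_q8_0_sketch(
        const block_q8_0_sketch * x, const int64_t ib, const int iqs, half2 & v) {
    const float d = __half2float(x[ib].d);        // per-block scale
    v = __floats2half2_rn(d * x[ib].qs[iqs + 0],  // quant pair -> half2
                          d * x[ib].qs[iqs + 1]);
}
```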
@@ -267,20 +267,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
-fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, half2, 2, convert_f16>;
+fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
-fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, half2, 2, convert_f16>;
+fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
@@ -289,6 +289,24 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+const ggml_tensor * K = dst->src[1];
+switch (K->type) {
+case GGML_TYPE_Q8_0:
+launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, dequantize_q8_0>(ctx, dst);
+break;
+case GGML_TYPE_F16:
+launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half2, 2, convert_f16>(ctx, dst);
+break;
+default:
+GGML_ASSERT(false);
+break;
+}
+}
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -299,18 +317,18 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
-launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
-launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
-launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
}
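The new launch_fattn_tile_f16_K_type level is the runtime-to-compile-time bridge: it switches on K->type once and feeds the matching (type_k, qkk, dequantize_k) triple into the existing head-size switch, so each supported K format and head size becomes its own kernel instantiation. A compressed, host-only model of that two-level dispatch, using hypothetical names and no CUDA:

```cpp
#include <cassert>
#include <cstdio>

struct fake_block_q8_0 { unsigned short d; signed char qs[32]; }; // 34 bytes, 32 K elements
struct fake_half2     { unsigned short x, y; };                    //  4 bytes,  2 K elements

enum fake_type { FAKE_F16, FAKE_Q8_0 };

template <int D, typename type_k, int qkk>
void fake_kernel_instance() {       // stands in for one flash_attn_tile_ext_f16 instantiation
    std::printf("D=%d sizeof(type_k)=%zu qkk=%d\n", D, sizeof(type_k), qkk);
}

template <typename type_k, int qkk>
void fake_launch_64_128(int ne0) {  // mirrors launch_fattn_tile_f16_64_128 (head-size switch)
    switch (ne0) {
        case  64: fake_kernel_instance< 64, type_k, qkk>(); break;
        case 128: fake_kernel_instance<128, type_k, qkk>(); break;
        default:  assert(false);
    }
}

void fake_launch_K_type(fake_type t, int ne0) { // mirrors launch_fattn_tile_f16_K_type
    switch (t) {
        case FAKE_Q8_0: fake_launch_64_128<fake_block_q8_0, 32>(ne0); break;
        case FAKE_F16:  fake_launch_64_128<fake_half2,       2>(ne0); break;
        default:        assert(false);
    }
}

int main() {
    fake_launch_K_type(FAKE_Q8_0, 128); // q8_0 K, head size 128
    fake_launch_K_type(FAKE_F16,   64); // f16  K, head size  64
    return 0;
}
```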

View file

@@ -464,7 +464,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int32_t precision = KQV->op_params[2];
-if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
+if (true || ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
return;
}