q8_0 works

This commit is contained in:
Johannes Gäßler 2024-05-20 11:29:29 +02:00
parent 8a10e5c03c
commit 75096c6e6e
2 changed files with 8 additions and 4 deletions

View file

@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
ggml_tensor * KQV = dst;
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);

View file

@ -56,8 +56,8 @@ static __global__ void flash_attn_tile_ext_f16(
const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_K = nb11 / sizeof(type_k);
const int stride_V = nb11*qkk / (sizeof(type_v)*qkv);
const int stride_K = nb11/sizeof(type_k);
const int stride_V = stride_K*qkk/qkv;
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@ -318,8 +318,13 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * V = dst->src[2];
switch (V->type) {
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
break;
case GGML_TYPE_F16:
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
break;
default:
GGML_ASSERT(false);