q8_0 works

commit 75096c6e6e
parent 8a10e5c03c

2 changed files with 8 additions and 4 deletions
```diff
@@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
```
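This hunk drops the host-side assertion that pinned V to F16, which is presumably what lets quantized V tensors through to the kernel at all. Below is a minimal sketch of the kind of guard that could stand in for the removed blanket check; the helper name is hypothetical, and the assumption that F16 and Q8_0 are the only V types accepted after this commit is mine, not stated in the diff.

```cpp
#include "ggml.h"

// Sketch only (hypothetical helper): accept the V types the tile kernel
// can now dispatch on, instead of asserting V->type == GGML_TYPE_F16.
static bool fattn_v_type_supported(ggml_type type_v) {
    switch (type_v) {
        case GGML_TYPE_F16:   // unquantized half-precision V
        case GGML_TYPE_Q8_0:  // blocks of 32 int8 values with one f16 scale
            return true;
        default:
            return false;
    }
}
```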
```diff
@@ -57,7 +57,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const half * maskh = (const half *) mask + ne11*ic0;
 
     const int stride_K = nb11/sizeof(type_k);
-    const int stride_V = nb11*qkk / (sizeof(type_v)*qkv);
+    const int stride_V = stride_K*qkk/qkv;
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
```
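The rewritten stride is pure element counting: stride_K*qkk is the number of elements in one row (K blocks per row times elements per K block), and dividing by qkv recounts the same row in V blocks. The replaced expression divided by sizeof(type_v), which truncates to zero in integer arithmetic once type_v is the 34-byte block_q8_0 (assuming the usual ggml block layout of one f16 scale plus 32 int8 quants), which would explain why q8_0 only "works" after this change. A worked check with hypothetical concrete values:

```cpp
#include <cassert>

int main() {
    // Hypothetical concrete values: head dim D = 128, K stored as f16
    // and traversed as half2 (qkk = 2 elements, 4 bytes), V stored as
    // q8_0 (qkv = 32 elements per block).
    const int D             = 128;
    const int nb11          = D * 2;  // bytes per K row at 2 bytes/element
    const int sizeof_type_k = 4;      // sizeof(half2)
    const int qkk           = 2;
    const int qkv           = 32;

    const int stride_K = nb11 / sizeof_type_k;  // 256/4   = 64 half2 per row
    const int stride_V = stride_K * qkk / qkv;  // 64*2/32 = 4 q8_0 blocks

    assert(stride_K * qkk == D);  // blocks/row * elements/block = elements/row
    assert(stride_V == D / qkv);  // the same row counted in V blocks

    // The replaced expression under the same values:
    const int sizeof_type_v = 34;  // sizeof(block_q8_0): f16 scale + 32 int8
    assert(nb11 * qkk / (sizeof_type_v * qkv) == 0);  // truncates to zero
    return 0;
}
```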
```diff
@@ -318,8 +318,13 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * V = dst->src[2];
 
     switch (V->type) {
+        case GGML_TYPE_Q8_0:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            break;
         case GGML_TYPE_F16:
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
```
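The launcher turns a runtime type tag into compile-time template arguments (V's block type, block size QK8_0, quantization ratio QR8_0, and dequantize function), so each case instantiates a fully specialized kernel. A self-contained sketch of that dispatch pattern, with stand-in names rather than anything from the diff:

```cpp
#include <cstdio>

// Dispatch pattern sketch: a runtime enum selects compile-time template
// parameters, mirroring how the switch above picks a block size and
// dequantize function per V type. All names here are stand-ins.
enum class vtype { f16, q8_0 };

float dequant_f16 (int q) { return (float) q; }        // stand-in
float dequant_q8_0(int q) { return 0.5f * (float) q; } // stand-in

template <int block_size, float (*dequant)(int)>
void kernel_stub() {
    std::printf("block_size=%d, dequant(1)=%.2f\n", block_size, dequant(1));
}

void dispatch(vtype t) {
    switch (t) {
        case vtype::q8_0: kernel_stub<32, dequant_q8_0>(); break; // like GGML_TYPE_Q8_0
        case vtype::f16:  kernel_stub< 2, dequant_f16 >(); break; // like GGML_TYPE_F16
    }
}

int main() {
    dispatch(vtype::q8_0);
    dispatch(vtype::f16);
    return 0;
}
```

The real switch carries the analogous parameters for K (type_k, qkk, qrk, dequantize_k) as fixed template arguments, so the K and V types can vary independently.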