q4_0 works

This commit is contained in:
Johannes Gäßler 2024-05-20 11:46:02 +02:00
parent 75096c6e6e
commit 14e80c413b

View file

@ -273,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16(
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
if (qrv == 1) {
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
} else {
const int iqs = (i0%qkv)/qrv;
const int iybs = i0 - i0%qkv;
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val);
}
}
if (parallel_blocks != 1 && threadIdx.x == 0) {
@ -318,6 +326,10 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * V = dst->src[2];
switch (V->type) {
case GGML_TYPE_Q4_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
break;
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);