q4_0 works
This commit is contained in:
parent
75096c6e6e
commit
14e80c413b
1 changed file with 14 additions and 2 deletions
|
@@ -273,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||
|
||||
if (qrv == 1) {
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||
} else {
|
||||
const int iqs = (i0%qkv)/qrv;
|
||||
const int iybs = i0 - i0%qkv;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val);
|
||||
}
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
|
@@ -318,6 +326,10 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
switch (V->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
launch_fattn_tile_f16_64_128<
|
||||
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
launch_fattn_tile_f16_64_128<
|
||||
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue