q4_0 works
This commit is contained in:
parent
75096c6e6e
commit
14e80c413b
1 changed file with 14 additions and 2 deletions
|
@@ -273,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
dst_val /= __half2half2(kqsum_j);
|
dst_val /= __half2half2(kqsum_j);
|
||||||
}
|
}
|
||||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||||
|
|
||||||
|
if (qrv == 1) {
|
||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||||
|
} else {
|
||||||
|
const int iqs = (i0%qkv)/qrv;
|
||||||
|
const int iybs = i0 - i0%qkv;
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val);
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||||
|
@@ -318,6 +326,10 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
switch (V->type) {
|
switch (V->type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
launch_fattn_tile_f16_64_128<
|
||||||
|
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
launch_fattn_tile_f16_64_128<
|
launch_fattn_tile_f16_64_128<
|
||||||
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
|
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue