cuda: 1.2x faster dequantization kernel

This commit is contained in:
lijiahao 2023-08-26 20:38:32 +08:00
parent a2ca4e9de9
commit 4c93e55996

View file

@ -4197,9 +4197,53 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
#ifdef GGML_CUDA_F16
return __hmul2(a, b);
#else
return make_float2(a.x * b.x, a.y * b.y);
#endif
}
static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) {
#ifdef GGML_CUDA_F16
return __half22float2(a);
#else
return a;
#endif
}
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i*4 >= k) {
return;
}
const int ib = i/(QK4_0/4);
const int iqs = i%(QK4_0/4);
const block_q4_0 * x = (const block_q4_0 *) vx;
const uchar2 qs = *(uchar2 *)(x[ib].qs + iqs*2);
const dfloat d = x[ib].d;
dfloat2 dv0;
dv0.x = (int)(qs.x & 0xf) - 8;
dv0.y = (int)(qs.y & 0xf) - 8;
float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
*(float2 *)(y + ib*QK4_0 + iqs*2) = v0;
dfloat2 dv1;
dv1.x = (int)(qs.x >> 4) - 8;
dv1.y = (int)(qs.y >> 4) - 8;
float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
*(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1;
}
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
GGML_ASSERT(k % 4 == 0);
const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block_q4_0<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {