From 5aae4b8d4ff50fe11065835953b1b80a4b9fdbcc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 09:40:33 +0300 Subject: [PATCH] k_quants: WIP super-blocks with 64 weights Q4_K working on CUDA. ~10% slower on GTX-1660, 16% slower on 4080. --- ggml-cuda.cu | 91 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ddef2400a..2e86226fb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -142,13 +142,21 @@ typedef struct { } block_q3_K; //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +#ifdef GGML_QKK_64 +typedef struct { + half d[2*QK_K/32]; // super-block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding"); +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; -//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif typedef struct { half d; // super-block scale for quantized scales @@ -420,13 +428,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const int i = blockIdx.x; - //// assume 64 threads - this is very slightly better than the one below - //const int tid = threadIdx.x; - //const int il = tid/16; - //const int ir = tid%16; - //const int is = 2*il; - //const int n = 2; - +#if QK_K == 256 // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; @@ -450,6 +452,13 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { y[l + 0] = d1 * (q[l] & 0xF) - m1; y[l +32] = d2 * (q[l] >> 4) - m2; } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1]; + y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3]; +#endif } static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { @@ -681,10 +690,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -693,6 +698,13 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const int il = tid/step; // 0...3 @@ -709,8 +721,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q4_K * x = (const block_q4_K *)vx + ib0; - float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { @@ -739,6 +749,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; } +#else + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const half2 * d = (const half2 *)x[i].d; + float2 df1 = __half22float2(d[0]); + float2 df2 = __half22float2(d[1]); + //float2 sum = {0.f, 0.f}; + //float smin = 0; + //for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + // sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF); + // sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4); + // smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]); + //} + //tmp += df1.x * sum.x + df2.x * sum.y - smin; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y) + + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y) + + y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y) + + y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y); + } + tmp += sum; + } + +#endif // sum up partial sums and write back result __syncthreads(); @@ -904,14 +944,14 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float #else - const int tid = threadIdx.x/4; // 0...7 - const int ix = threadIdx.x%4; // 0...3 + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 - const int step = tid; + const int step = tid * K_QUANTS_PER_ITERATION; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 4) { + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + step; const uint8_t * ql = x[i].ql + step; @@ -920,14 +960,13 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const float d = x[i+0].d; - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32) - + y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32) - + y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32) - + y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32) - + y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32); + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } tmp += sum; }