From 7ced19712726af6d6d3b4bf451e85f986a78ac2e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 16 Jun 2023 11:27:09 +0300 Subject: [PATCH] Imrove Q2_K dot kernel on older GPUs We now have a K_QUANTS_PER_ITERATION macro, which should be set to 1 on older and to 2 on newer GPUs. With this, we preserve the performance of the original PR on RTX-4080, and are faster compared to master on GTX-1660. --- ggml-cuda.cu | 123 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 82 insertions(+), 41 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b77020e06..f9f7042c4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -461,47 +461,70 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } -static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols) { +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 1 +#endif + +static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; - const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; const block_q2_K * x = (const block_q2_K *)vx + ib0; - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; // 0, 1 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 - const int n = 4; // iterations in the inner loop - const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - 8*im; // 0...7 - const int is = in/4; + const int step = 16/K_QUANTS_PER_ITERATION; - const int l0 = n*in; // 0...28 in steps of 4 - const int q_offset = 32*im + l0; - const int s_offset = 8*im + is; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; const int y_offset = 128*im + l0; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; - const uint8_t * s = x[i].scales + s_offset; const float dall = x[i].d; const float dmin = x[i].dmin; - // This is about the same as the one below. - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (dall * ((s[0] & 0xF) * ((q[l] >> 0) & 3)) - dmin * (s[0] >> 4)) - + y[l+32] * (dall * ((s[2] & 0xF) * ((q[l] >> 2) & 3)) - dmin * (s[2] >> 4)) - + y[l+64] * (dall * ((s[4] & 0xF) * ((q[l] >> 4) & 3)) - dmin * (s[4] >> 4)) - + y[l+96] * (dall * ((s[6] & 0xF) * ((q[l] >> 6) & 3)) - dmin * (s[6] >> 4)); + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + } - tmp += sum; + tmp += dall * sum1 - dmin * sum2; } @@ -736,32 +759,32 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols) { +static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; - //const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; const block_q6_K * x = (const block_q6_K *)vx + ib0; - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; // 0, 1 + const int tid = threadIdx.x/1; // 0...31 + const int ix = threadIdx.x%1; // 0 - const int n = 4; // iterations in the inner loop - const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - 8*im; // 0...7 - const int is = in/4; + const int n = 1; // iterations in the inner loop + const int im = tid/16; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - 16*im; // 0...15 - const int l0 = n*in; // 0...28 in steps of 4 + const int l0 = n*in; // 0...15 const int ql_offset = 64*im + l0; const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; + const int s_offset = 8*im; const int y_offset = 128*im + l0; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { + for (int i = ix; i < num_blocks_per_row; i += 1) { const float * y = yy + i * QK_K + y_offset; const uint8_t * ql = x[i].ql + ql_offset; @@ -772,10 +795,22 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float float sum = 0; for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + //sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + // + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + // + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + // + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32) + // + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | (((qh[l+16] >> 0) & 3) << 4)) - 32) + // + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | (((qh[l+16] >> 2) & 3) << 4)) - 32) + // + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | (((qh[l+16] >> 4) & 3) << 4)) - 32) + // +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | (((qh[l+16] >> 6) & 3) << 4)) - 32); + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32) + + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l+16] & 0x03) << 4)) - 32) + + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | ((qh[l+16] & 0x0c) << 2)) - 32) + + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l+16] & 0x30) >> 0)) - 32) + +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | ((qh[l+16] & 0xc0) >> 2)) - 32); } tmp += sum; @@ -1210,8 +1245,11 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -1234,8 +1272,11 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {