Imrove Q2_K dot kernel on older GPUs

We now have a K_QUANTS_PER_ITERATION macro, which should be
set to 1 on older and to 2 on newer GPUs.
With this, we preserve the performance of the original
PR on RTX-4080, and are faster compared to master on
GTX-1660.
This commit is contained in:
Iwan Kawrakow 2023-06-16 11:27:09 +03:00
parent dc67f1a06e
commit 7ced197127

View file

@ -461,47 +461,70 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
} }
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols) { #ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#endif
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row;
const block_q2_K * x = (const block_q2_K *)vx + ib0; const block_q2_K * x = (const block_q2_K *)vx + ib0;
const int tid = threadIdx.x/2; // 0...15 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31
const int ix = threadIdx.x%2; // 0, 1 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0
const int n = 4; // iterations in the inner loop const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - 8*im; // 0...7
const int is = in/4;
const int l0 = n*in; // 0...28 in steps of 4 const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int q_offset = 32*im + l0; const int in = tid - step*im; // 0...7
const int s_offset = 8*im + is;
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) { uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
const uint8_t * s = x[i].scales + s_offset;
const float dall = x[i].d; const float dall = x[i].d;
const float dmin = x[i].dmin; const float dmin = x[i].dmin;
// This is about the same as the one below. const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
float sum = 0; aux[0] = a[0] & 0x0f0f0f0f;
for (int l = 0; l < n; ++l) { aux[1] = a[1] & 0x0f0f0f0f;
sum += y[l+ 0] * (dall * ((s[0] & 0xF) * ((q[l] >> 0) & 3)) - dmin * (s[0] >> 4)) aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+ y[l+32] * (dall * ((s[2] & 0xF) * ((q[l] >> 2) & 3)) - dmin * (s[2] >> 4)) aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+ y[l+64] * (dall * ((s[4] & 0xF) * ((q[l] >> 4) & 3)) - dmin * (s[4] >> 4))
+ y[l+96] * (dall * ((s[6] & 0xF) * ((q[l] >> 6) & 3)) - dmin * (s[6] >> 4)); float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
} }
tmp += sum; tmp += dall * sum1 - dmin * sum2;
} }
@ -736,32 +759,32 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
} }
} }
static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols) { static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
//const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row;
const block_q6_K * x = (const block_q6_K *)vx + ib0; const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/2; // 0...15 const int tid = threadIdx.x/1; // 0...31
const int ix = threadIdx.x%2; // 0, 1 const int ix = threadIdx.x%1; // 0
const int n = 4; // iterations in the inner loop const int n = 1; // iterations in the inner loop
const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/16; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - 8*im; // 0...7 const int in = tid - 16*im; // 0...15
const int is = in/4;
const int l0 = n*in; // 0...28 in steps of 4 const int l0 = n*in; // 0...15
const int ql_offset = 64*im + l0; const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0; const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is; const int s_offset = 8*im;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) { for (int i = ix; i < num_blocks_per_row; i += 1) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset; const uint8_t * ql = x[i].ql + ql_offset;
@ -772,10 +795,22 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
float sum = 0; float sum = 0;
for (int l = 0; l < n; ++l) { for (int l = 0; l < n; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) //sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) // + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) // + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); // + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32)
// + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | (((qh[l+16] >> 0) & 3) << 4)) - 32)
// + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | (((qh[l+16] >> 2) & 3) << 4)) - 32)
// + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | (((qh[l+16] >> 4) & 3) << 4)) - 32)
// +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | (((qh[l+16] >> 6) & 3) << 4)) - 32);
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32)
+ y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l+16] & 0x03) << 4)) - 32)
+ y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | ((qh[l+16] & 0x0c) << 2)) - 32)
+ y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l+16] & 0x30) >> 0)) - 32)
+y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | ((qh[l+16] & 0xc0) >> 2)) - 32);
} }
tmp += sum; tmp += sum;
@ -1210,8 +1245,11 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1); const int ny = 2;
dequantize_mul_mat_vec_q2_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@ -1234,8 +1272,11 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1); const int ny = 2;
dequantize_mul_mat_vec_q6_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {