Imrove Q6_K dot kernel on older GPUs

Using the same K_QUANTS_PER_ITERATION macro as last commit,
we preserve performance on RTX-4080 and speed up
Q6_K on a GTX-1660.
This commit is contained in:
Iwan Kawrakow 2023-06-16 12:43:20 +03:00
parent 7ced197127
commit 3edee085ea

View file

@ -167,6 +167,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define GGML_CUDA_DMMV_Y 1 #define GGML_CUDA_DMMV_Y 1
#endif #endif
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#endif
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -461,10 +465,6 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
} }
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#endif
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@ -761,6 +761,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.y*blockDim.y + threadIdx.y; const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return; if (row > nrows) return;
@ -769,22 +771,29 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
const block_q6_K * x = (const block_q6_K *)vx + ib0; const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/1; // 0...31 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%1; // 0 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int n = 1; // iterations in the inner loop const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/16; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - 16*im; // 0...15
const int l0 = n*in; // 0...15 const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0; const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0; const int qh_offset = 32*im + l0;
const int s_offset = 8*im; const int s_offset = 8*im + is;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 1) { for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset; const uint8_t * ql = x[i].ql + ql_offset;
@ -793,26 +802,26 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
const float d = x[i].d; const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0; float sum = 0;
for (int l = 0; l < n; ++l) { for (int l = 0; l < 4; ++l) {
//sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
// + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
// + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
// + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32) + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
// + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | (((qh[l+16] >> 0) & 3) << 4)) - 32)
// + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | (((qh[l+16] >> 2) & 3) << 4)) - 32)
// + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | (((qh[l+16] >> 4) & 3) << 4)) - 32)
// +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | (((qh[l+16] >> 6) & 3) << 4)) - 32);
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32)
+ y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l+16] & 0x03) << 4)) - 32)
+ y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | ((qh[l+16] & 0x0c) << 2)) - 32)
+ y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l+16] & 0x30) >> 0)) - 32)
+y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | ((qh[l+16] & 0xc0) >> 2)) - 32);
} }
tmp += sum; tmp += sum;
#endif
} }
@ -1272,7 +1281,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);