diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f9f7042c4..70dcb19db 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -167,6 +167,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define GGML_CUDA_DMMV_Y 1 #endif +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 1 +#endif + static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -461,10 +465,6 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 1 -#endif - static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -761,6 +761,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -769,22 +771,29 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const block_q6_K * x = (const block_q6_K *)vx + ib0; - const int tid = threadIdx.x/1; // 0...31 - const int ix = threadIdx.x%1; // 0 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - const int n = 1; // iterations in the inner loop - const int im = tid/16; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - 16*im; // 0...15 + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - const int l0 = n*in; // 0...15 + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif const int ql_offset = 64*im + l0; const int qh_offset = 32*im + l0; - const int s_offset = 8*im; + const int s_offset = 8*im + is; const int y_offset = 128*im + l0; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 1) { + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; const uint8_t * ql = x[i].ql + ql_offset; @@ -793,26 +802,26 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const float d = x[i].d; +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else float sum = 0; - for (int l = 0; l < n; ++l) { - //sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - // + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - // + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - // + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32) - // + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | (((qh[l+16] >> 0) & 3) << 4)) - 32) - // + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | (((qh[l+16] >> 2) & 3) << 4)) - 32) - // + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | (((qh[l+16] >> 4) & 3) << 4)) - 32) - // +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | (((qh[l+16] >> 6) & 3) << 4)) - 32); - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32) - + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l+16] & 0x03) << 4)) - 32) - + y[l+48] * s[3] * d * ((int8_t)((ql[l+48] & 0xF) | ((qh[l+16] & 0x0c) << 2)) - 32) - + y[l+80] * s[5] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l+16] & 0x30) >> 0)) - 32) - +y[l+112] * s[7] * d * ((int8_t)((ql[l+48] >> 4) | ((qh[l+16] & 0xc0) >> 2)) - 32); + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); } tmp += sum; +#endif } @@ -1272,7 +1281,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; + const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1);