k_quants: hopefully much faster Q3_K on older GPUs

On the GTX-1660 that I have available to represent "old GPUs", token prediction time drops from 60.3 ms/tok to 41.0 ms/tok!
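The gain comes from the launch-geometry change below: instead of one 32-thread block per matrix row indexed via blockIdx.x, the kernel now packs ny = 2/K_QUANTS_PER_ITERATION rows into each block along y, and each thread consumes K_QUANTS_PER_ITERATION super-blocks per trip through the inner loop. As a quick check of the arithmetic, here is a host-side sketch (not part of the commit; the row count is made up) that evaluates the new grid and block shapes for both values K_QUANTS_PER_ITERATION can take (1 or 2, judging by the comments in the diff):

#include <cstdio>
#include <initializer_list>

int main() {
    const int nrows = 4096;  // hypothetical number of matrix rows
    for (const int k_per_iter : {1, 2}) {                // the two allowed settings
        const int ny          = 2 / k_per_iter;          // rows per thread block
        const int block_num_y = (nrows + ny - 1) / ny;   // blocks along grid y
        printf("K_QUANTS_PER_ITERATION=%d: block_dims=(32,%d,1), block_nums=(1,%d,1)\n",
               k_per_iter, ny, block_num_y);
    }
    return 0;
}

K_QUANTS_PER_ITERATION = 2 keeps the old one-warp-per-row shape (just moved from grid x to grid y), while K_QUANTS_PER_ITERATION = 1 doubles the threads per block and halves the grid.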
Iwan Kawrakow 2023-06-18 16:50:15 +03:00
parent 1677059ba1
commit be6f8b9ee7

ggml-cuda.cu

@@ -546,27 +546,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int row = blockIdx.x;
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
-    const int tid = threadIdx.x/2;  // 0...15
-    const int ix  = threadIdx.x%2;  // 0, 1
-    const int n   = 2;              // iterations in the inner loop
-    const int im  = tid/8;          // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in  = tid - 8*im;     // 0...7
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int n   = K_QUANTS_PER_ITERATION;              // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im  = tid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in  = tid - step*im;                       // 0....15 or 0...7
 
     const uint8_t m = 1 << (4*im);
 
-    const int l0 = n*in;            // 0...28 in steps of 4
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
     const int q_offset =  32*im + l0;
     const int y_offset = 128*im + l0;
@@ -577,7 +580,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 2) {
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -1272,8 +1275,11 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(32, 1, 1);
-    dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
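
To see the new thread mapping on its own, here is a minimal self-contained sketch. Everything in it is illustrative: toy_mul_mat_vec is a made-up stand-in that does a plain float dot product per row instead of Q3_K dequantization, keeping only the commit's row indexing, an early-exit guard (written as >=, the safe form when nrows is odd), and a warp-shuffle reduction of the kind these kernels use for the final sum.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2   // assumed compile-time setting: 1 or 2
#endif

// Made-up stand-in for dequantize_mul_mat_vec_q3_k: same launch geometry,
// but a plain float dot product per row.
static __global__ void toy_mul_mat_vec(const float * x, const float * y, float * dst,
                                       const int ncols, const int nrows) {
    const int row = blockIdx.y*blockDim.y + threadIdx.y;  // which matrix row this warp owns
    if (row >= nrows) return;  // all 32 x-lanes of a warp share one row, so whole warps exit together

    float tmp = 0; // partial sum for thread in warp
    for (int i = threadIdx.x; i < ncols; i += 32) {  // 32 lanes stride across the row
        tmp += x[row*ncols + i] * y[i];
    }

    // sum the 32 partial sums within the warp
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    if (threadIdx.x == 0) dst[row] = tmp;
}

int main() {
    const int ncols = 1024, nrows = 8;  // arbitrary toy sizes
    std::vector<float> hx(nrows*ncols, 1.0f), hy(ncols, 2.0f), hdst(nrows, 0.0f);

    float *dx, *dy, *ddst;
    cudaMalloc(&dx,   hx.size()*sizeof(float));
    cudaMalloc(&dy,   hy.size()*sizeof(float));
    cudaMalloc(&ddst, hdst.size()*sizeof(float));
    cudaMemcpy(dx, hx.data(), hx.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy.data(), hy.size()*sizeof(float), cudaMemcpyHostToDevice);

    // Launch geometry copied from the new dequantize_mul_mat_vec_q3_K_cuda.
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const dim3 block_nums(1, (nrows + ny - 1)/ny, 1);
    const dim3 block_dims(32, ny, 1);
    toy_mul_mat_vec<<<block_nums, block_dims>>>(dx, dy, ddst, ncols, nrows);

    cudaMemcpy(hdst.data(), ddst, hdst.size()*sizeof(float), cudaMemcpyDeviceToHost);
    printf("dst[0] = %.1f (expected %.1f)\n", hdst[0], 2.0f*ncols);
    cudaFree(dx); cudaFree(dy); cudaFree(ddst);
    return 0;
}

Because blockDim.x is 32, every x-slice of a block is exactly one warp, so all lanes that reach the __shfl_xor_sync loop share the same row and either all exit at the guard or none do.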