Some more CUDA optimizations for Q3_K

Single token is now 20.5 ms/token (~20% slower than Q4_0).
Perplexity is on par with Q4_0.
This commit is contained in:
Iwan Kawrakow 2023-05-29 09:16:45 +03:00
parent a3c0673089
commit 3d8b1de3f7

View file

@ -202,43 +202,32 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
int r = threadIdx.x/4;
int i = blockIdx.x;
int tid = threadIdx.x;
int tid = r/2;
int is0 = r%2;
int l0 = 16*is0 + 4*(threadIdx.x%4);
int n = tid / 4;
int j = tid - 4*n;
const block_q3_K * x = (const block_q3_K *) vx;
float * y = yy + i*QK_K + 128*n + 32*j;
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
float d_all = x[i].d;
float dl = d_all * (us - 32);
float * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
memcpy(aux, x[i].scales, 12);
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j;
float dl;
int shift = 2*j;
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
dl = d_all * (scales[is++] - 32);
for (int l = 16; l < 32; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
}
@ -251,17 +240,17 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs
int j = iqsn / 32;
int l = iqsn - 32*j;
int shift = 2*j;
int is = 8*n + 2*j + l/16;
int iss = 2*j + l/16;
int is = 8*n + iss;
int is_shift = 2*(is/4);
uint8_t m = 1 << (4*n + j);
const float d = x[ib].d;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * hm = x[ib].hmask + l;
int8_t us = is < 4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is+0] >> 4) & 3) << 4) :
(x[ib].scales[is-8] >> 4) | (((x[ib].scales[is-4] >> 6) & 3) << 4);
int8_t us = n == 0 ? (x[ib].scales[iss] & 0xF) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4)
: (x[ib].scales[iss] >> 4 ) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4);
float scale = d * (us - 32);
float sum = 0;
@ -346,19 +335,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
}
}
template <dot_kernel_k_t dot_kernel>
template <int n_thread, dot_kernel_k_t dot_kernel>
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int iter_stride = QK_K;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int vals_per_iter = iter_stride / n_thread;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
float tmp = 0; // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/QK_K; // x block index
const int ib = ib0 + col/QK_K; // x block index
const int iqs = col%QK_K; // x quant index
const int iybs = col - col%QK_K; // y block start index
@ -411,7 +402,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 8, 0, stream>>>(vx, y);
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@ -456,8 +447,8 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(WARP_SIZE, 2, 1);
dequantize_mul_mat_vec_k<vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
const dim3 block_dims(32, 2, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {