Some more CUDA optimizations for Q3_K
Single-token prediction is now 20.5 ms/token (~20% slower than Q4_0). Perplexity is on par with Q4_0.
parent a3c0673089
commit 3d8b1de3f7

1 changed file with 27 additions and 36 deletions
ggml-cuda.cu (63 lines changed)
@@ -202,43 +202,32 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
-
+    int r = threadIdx.x/4;
     int i = blockIdx.x;
-    int tid = threadIdx.x;
+    int tid = r/2;
+    int is0 = r%2;
+    int l0 = 16*is0 + 4*(threadIdx.x%4);
     int n = tid / 4;
     int j = tid - 4*n;
 
     const block_q3_K * x = (const block_q3_K *) vx;
 
-    float * y = yy + i*QK_K + 128*n + 32*j;
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
+
+    float * y = yy + i*QK_K + 128*n + 32*j;
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;
 
-    uint32_t aux[4];
-    const int8_t * scales = (const int8_t*)aux;
-
-    memcpy(aux, x[i].scales, 12);
-    uint32_t tmp = aux[2];
-    aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-    aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-    aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-    aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-
-    uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j;
-    float dl;
-    int shift = 2*j;
-
-    dl = d_all * (scales[is++] - 32);
-    for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
-
-    dl = d_all * (scales[is++] - 32);
-    for (int l = 16; l < 32; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 
 }
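Reading note on the new kernel's index math: each block now runs 64 threads and each thread writes 4 consecutive output values, so a block still covers the QK_K = 256 values of one super-block (the old kernel used 8 threads writing 32 values each, which is why the launcher further down moves from 8 to 64 threads per block). A minimal host-side sketch, assuming QK_K == 256 as in the k-quants layout, that checks the mapping covers every offset exactly once:

    // Host-side sketch (assumes QK_K == 256): mirror the new thread mapping and
    // verify that 64 threads x 4 values cover all 256 offsets of a super-block once.
    #include <cassert>
    #include <set>

    int main() {
        std::set<int> covered;
        for (int t = 0; t < 64; ++t) {      // t plays the role of threadIdx.x
            int r   = t/4;
            int tid = r/2;
            int is0 = r%2;
            int l0  = 16*is0 + 4*(t%4);
            int n   = tid / 4;              // which 128-value half of the block
            int j   = tid - 4*n;            // which 32-value group inside that half
            for (int l = l0; l < l0+4; ++l) covered.insert(128*n + 32*j + l);
        }
        assert(covered.size() == 256);      // every y offset written exactly once
        return 0;
    }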
@@ -251,17 +240,17 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs
     int j = iqsn / 32;
     int l = iqsn - 32*j;
     int shift = 2*j;
-    int is = 8*n + 2*j + l/16;
+    int iss = 2*j + l/16;
+    int is = 8*n + iss;
+    int is_shift = 2*(is/4);
     uint8_t m = 1 << (4*n + j);
 
     const float d = x[ib].d;
     const uint8_t * q = x[ib].qs + 32*n + l;
     const uint8_t * hm = x[ib].hmask + l;
 
-    int8_t us = is < 4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) :
-                is < 8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) :
-                is < 12 ? (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is+0] >> 4) & 3) << 4) :
-                          (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is-4] >> 6) & 3) << 4);
+    int8_t us = n == 0 ? (x[ib].scales[iss] & 0xF) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4)
+                       : (x[ib].scales[iss] >> 4 ) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4);
     float scale = d * (us - 32);
 
     float sum = 0;
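The rewritten scale lookup in vec_dot_q3_K trades the four-way ternary on is for a two-way one on n, using iss = is - 8*n for the low nibble and is_shift = 2*(is/4) for the packed high bits. A throwaway host-side check, with an arbitrary 12-byte test pattern standing in for the block_q3_K scales, can confirm the two forms agree for all 16 scale indices:

    // Host-side equivalence check; the 12-byte array stands in for x[ib].scales
    // (arbitrary test values, not real model data).
    #include <cassert>
    #include <cstdint>

    int main() {
        uint8_t scales[12];
        for (int k = 0; k < 12; ++k) scales[k] = (uint8_t)(37*k + 11);

        for (int n = 0; n < 2; ++n) {
            for (int iss = 0; iss < 8; ++iss) {
                int is       = 8*n + iss;
                int is_shift = 2*(is/4);
                // previous four-way form
                int8_t us_old = is < 4 ? (scales[is-0] & 0xF) | (((scales[is+8] >> 0) & 3) << 4) :
                                is < 8 ? (scales[is-0] & 0xF) | (((scales[is+4] >> 2) & 3) << 4) :
                                is < 12 ? (scales[is-8] >> 4) | (((scales[is+0] >> 4) & 3) << 4) :
                                          (scales[is-8] >> 4) | (((scales[is-4] >> 6) & 3) << 4);
                // new two-way form
                int8_t us_new = n == 0 ? (scales[iss] & 0xF) | (((scales[is+8-2*is_shift] >> is_shift) & 3) << 4)
                                       : (scales[iss] >> 4) | (((scales[is+8-2*is_shift] >> is_shift) & 3) << 4);
                assert(us_old == us_new);
            }
        }
        return 0;
    }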
@@ -346,19 +335,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
-template <dot_kernel_k_t dot_kernel>
+template <int n_thread, dot_kernel_k_t dot_kernel>
 static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const int iter_stride = QK_K;
-    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int vals_per_iter = iter_stride / n_thread;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = 0; i < ncols; i += iter_stride) {
         const int col = i + vals_per_iter*tid;
-        const int ib = (row*ncols + col)/QK_K; // x block index
+        const int ib = ib0 + col/QK_K; // x block index
         const int iqs = col%QK_K; // x quant index
         const int iybs = col - col%QK_K; // y block start index
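Two things happen in this hunk: the number of threads cooperating on a row becomes a template parameter (n_thread) instead of being tied to WARP_SIZE, and the block index is hoisted out of the loop. The hoisting relies on ncols being a multiple of QK_K (asserted in the launcher below), in which case (row*ncols + col)/QK_K == row*(ncols/QK_K) + col/QK_K. A small host-side sketch of that identity, using arbitrary test sizes:

    // Host-side sketch of the hoisted block index (assumes ncols % QK_K == 0;
    // sizes below are arbitrary test values).
    #include <cassert>

    int main() {
        const int QK_K  = 256;
        const int ncols = 4096;                        // multiple of QK_K
        for (int row = 0; row < 8; ++row) {
            const int ib0 = row*(ncols/QK_K);          // ib0 = row*num_blocks_per_row
            for (int col = 0; col < ncols; ++col) {
                assert((row*ncols + col)/QK_K == ib0 + col/QK_K);
            }
        }
        return 0;
    }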
@@ -411,7 +402,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
-    dequantize_block_q3_K<<<nb, 8, 0, stream>>>(vx, y);
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -456,8 +447,8 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const dim3 block_dims(WARP_SIZE, 2, 1);
-    dequantize_mul_mat_vec_k<vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+    const dim3 block_dims(32, 2, 1);
+    dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
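On the launch parameters: the grid stays at nrows/2 blocks with block_dims(32, 2, 1), so threadIdx.y selects one of two rows per block while the 32 threads along x match the new n_thread = 32 template argument (giving vals_per_iter = QK_K/32 = 8). A quick host-side sketch of the row coverage, assuming nrows is even as the nrows/2 grid implies:

    // Host-side sketch: nrows/2 blocks of shape (32, 2, 1) visit each row once
    // (assumes nrows is even; nrows below is an arbitrary test size).
    #include <cassert>
    #include <vector>

    int main() {
        const int nrows = 64;
        std::vector<int> hits(nrows, 0);
        for (int bx = 0; bx < nrows/2; ++bx) {         // blockIdx.x
            for (int ty = 0; ty < 2; ++ty) {           // threadIdx.y
                ++hits[bx*2 + ty];                     // row = blockIdx.x*blockDim.y + threadIdx.y
            }
        }
        for (int r = 0; r < nrows; ++r) assert(hits[r] == 1);
        return 0;
    }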