diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a59110291..a6bfc3ca7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -202,43 +202,32 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - + int r = threadIdx.x/4; int i = blockIdx.x; - int tid = threadIdx.x; + int tid = r/2; + int is0 = r%2; + int l0 = 16*is0 + 4*(threadIdx.x%4); int n = tid / 4; int j = tid - 4*n; const block_q3_K * x = (const block_q3_K *) vx; - float * y = yy + i*QK_K + 128*n + 32*j; + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); float d_all = x[i].d; + float dl = d_all * (us - 32); + float * y = yy + i*QK_K + 128*n + 32*j; const uint8_t * q = x[i].qs + 32*n; const uint8_t * hm = x[i].hmask; - uint32_t aux[4]; - const int8_t * scales = (const int8_t*)aux; - - memcpy(aux, x[i].scales, 12); - uint32_t tmp = aux[2]; - aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j; - float dl; - int shift = 2*j; - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); - - dl = d_all * (scales[is++] - 32); - for (int l = 16; l < 32; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); } @@ -251,17 +240,17 @@ static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs int j = iqsn / 32; int l = iqsn - 32*j; int shift = 2*j; - int is = 8*n + 2*j + l/16; + int iss = 2*j + l/16; + int is = 8*n + iss; + int is_shift = 2*(is/4); uint8_t m = 1 << (4*n + j); const float d = x[ib].d; const uint8_t * q = x[ib].qs + 32*n + l; const uint8_t * hm = x[ib].hmask + l; - int8_t us = is < 4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is+0] >> 4) & 3) << 4) : - (x[ib].scales[is-8] >> 4) | (((x[ib].scales[is-4] >> 6) & 3) << 4); + int8_t us = n == 0 ? (x[ib].scales[iss] & 0xF) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4) + : (x[ib].scales[iss] >> 4 ) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4); float scale = d * (us - 32); float sum = 0; @@ -346,19 +335,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } -template +template static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; const int iter_stride = QK_K; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int vals_per_iter = iter_stride / n_thread; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; float tmp = 0; // partial sum for thread in warp for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/QK_K; // x block index + const int ib = ib0 + col/QK_K; // x block index const int iqs = col%QK_K; // x quant index const int iybs = col - col%QK_K; // y block start index @@ -411,7 +402,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; - dequantize_block_q3_K<<>>(vx, y); + dequantize_block_q3_K<<>>(vx, y); } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -456,8 +447,8 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(WARP_SIZE, 2, 1); - dequantize_mul_mat_vec_k<<>>(vx, y, dst, ncols); + const dim3 block_dims(32, 2, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<>>(vx, y, dst, ncols); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {