diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 663e16c4a..d8e8d25c6 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -159,12 +159,12 @@ typedef struct { static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); -typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); typedef void (*load_tiles_cuda_t)( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row); + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row); typedef float (*vec_dot_q_mul_mat_cuda_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); //================================= k-quants @@ -1390,7 +1390,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds); } -static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { +static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI4_0)]; @@ -1401,7 +1401,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** static __device__ __forceinline__ void load_tiles_q4_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { const int kbx = k / QI4_0; const int kqsx = k % QI4_0; @@ -1413,7 +1413,7 @@ static __device__ __forceinline__ void load_tiles_q4_0( } static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); @@ -1462,7 +1462,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds); } -static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { +static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_1)]; @@ -1473,7 +1473,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** static __device__ __forceinline__ void load_tiles_q4_1( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { const int kbx = k / QI4_1; const int kqsx = k % QI4_1; @@ -1485,7 +1485,7 @@ static __device__ __forceinline__ void load_tiles_q4_1( } static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); @@ -1532,7 +1532,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds); } -static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { +static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)]; @@ -1545,7 +1545,7 @@ static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** static __device__ __forceinline__ void load_tiles_q5_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { const int kbx = k / QI5_0; const int kqsx = k % QI5_0; @@ -1558,7 +1558,7 @@ static __device__ __forceinline__ void load_tiles_q5_0( } static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); @@ -1616,7 +1616,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( return vec_dot_q5_1_q8_1_impl(qs, qh, ui0, ui1, bq5_1->dm, bq8_1->ds); } -static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { +static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_1)]; @@ -1629,7 +1629,7 @@ static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** static __device__ __forceinline__ void load_tiles_q5_1( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { const int kbx = k / QI5_1; const int kqsx = k % QI5_1; @@ -1642,7 +1642,7 @@ static __device__ __forceinline__ void load_tiles_q5_1( } static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); @@ -1677,7 +1677,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds); } -static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { +static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)]; @@ -1688,7 +1688,7 @@ static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** static __device__ __forceinline__ void load_tiles_q8_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { const int kbx = k / QI8_0; const int kqsx = k % QI8_0; @@ -1700,7 +1700,7 @@ static __device__ __forceinline__ void load_tiles_q8_0( } static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { return vec_dot_q8_0_q8_1_impl( @@ -2055,7 +2055,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl( - const int & vl, const int & vh, const int * u, const int8_t * scales, const float & d, const float * d8) { + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf = 0.0f; @@ -2103,6 +2104,66 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( return vec_dot_q6_K_q8_1_impl(vl, vh, u, scales, bq6_K->d, d8); } +static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI6_K)]; + __shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/2)]; + __shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE/8)]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI6_K; + const int kqsx = k % QI6_K; + + const block_q6_K * bx = ((block_q6_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->ql, kqsx); + x_dm[i * (WARP_SIZE / QI6_K) + kbx].x = bx->d; + x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->qh, kqsx/2); + x_sc[i * (WARP_SIZE / 8) + k/8] = get_int_from_int8(bx->scales, kqsx/8); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const int bq8_offset = 2 * QR6_K * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)); + + const int vh = x_qh[i * (WARP_SIZE/2) + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift; + + const int x_sc_offset = i * (WARP_SIZE/8) + kbx * (QI6_K/8); + const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + + for (int l = 0; l < QR6_K; ++l) { + const int kqsy = j * (QR6_K*WARP_SIZE) + kbx * (QR6_K*QI6_K) + (bq8_offset + 2*l)*QI8_1 + kqsx % QI8_1; + u[l] = y_qs[kqsy]; + d8[l] = y_ds[kqsy / QI8_1].x; + } + + return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, x_dm[i * (WARP_SIZE/QI6_K) + kbx].x, d8); +} + template static __global__ void mul_mat_q( @@ -2113,7 +2174,7 @@ static __global__ void mul_mat_q( const block_q8_1 * y = (const block_q8_1 *) vy; const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / qk; + const int blocks_per_col_y = nrows_y / QK8_1; const int blocks_per_warp = WARP_SIZE / qi; const int & ncols_dst = ncols_y; @@ -2128,10 +2189,10 @@ static __global__ void mul_mat_q( const int col_dst_0 = blockIdx.y*WARP_SIZE; const int & col_y_0 = col_dst_0; - int * tile_x_ql = nullptr; - half2 * tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int8_t * tile_x_sc = nullptr; + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); @@ -2156,7 +2217,7 @@ static __global__ void mul_mat_q( for (int i = 0; i < WARP_SIZE; i += 8) { const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses - const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 + kby]; + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby]; tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = *((int *) &by0->qs[iqsy]); tile_y_ds[(tid_y + i) * (qr*WARP_SIZE/QI8_1) + kby] = by0->ds; @@ -2952,6 +3013,18 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } +static void ggml_mul_mat_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + static void ggml_mul_mat_p021_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, cudaStream_t stream) { @@ -3465,6 +3538,9 @@ inline void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q8_0: ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; default: GGML_ASSERT(false); break; @@ -4233,7 +4309,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || - src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0) { + src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_Q6_K) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);