diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d1ae5076a..b6635e413 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1913,7 +1913,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( d8[l] = y_ds[y_qs_index / QI8_1].x; } - return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI2_K) + kbx].x, d8); + return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI3_K) + kbx].x, d8); } static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl( @@ -2034,6 +2034,79 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( #endif } +static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_K)]; + __shared__ int tile_x_sc[(2*WARP_SIZE) * (3*WARP_SIZE/32)]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI4_K; + const int kqsx = k % QI4_K; + + const block_q4_K * bx = ((block_q4_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm; + x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32)); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + if (k >= WARP_SIZE/2) { + return 0.0f; + } + + __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * (kqsx / (QI8_1/2)); + + v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4]; + v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4]; + + const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)]; + uint16_t aux[2]; + const int l = bq8_offset/2; + if (l < 2) { + aux[0] = scales[l+0] & 0x3f3f; + aux[1] = scales[l+2] & 0x3f3f; + } else { + aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2); + aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int l = 0; l < QR4_K; ++l) { + const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2); + u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)]; + u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)]; + d8[l] = y_ds[kqsy / QI8_1].x; + } + + return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8); +} + static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl( const int * vl, const int * vh, const int * u, const uint8_t * sc, const uint8_t * m, const half2 & dm5, const float * d8) { @@ -3137,6 +3210,18 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } +static void ggml_mul_mat_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + static void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -3668,6 +3753,9 @@ inline void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q3_K: ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; case GGML_TYPE_Q6_K: ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); break; @@ -4440,7 +4528,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } else { if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || - src0->type == GGML_TYPE_Q2_K || src0->type == GGML_TYPE_Q3_K || src0->type == GGML_TYPE_Q6_K) { + src0->type == GGML_TYPE_Q2_K || src0->type == GGML_TYPE_Q3_K || src0->type == GGML_TYPE_Q4_K || + src0->type == GGML_TYPE_Q6_K) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);