This commit is contained in:
JohannesGaessler 2023-07-27 21:39:04 +02:00
parent 5bff3df032
commit a62bcc891c

View file

@ -1861,6 +1861,61 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
return vec_dot_q3_K_q8_1_impl(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)];
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE / 2)];
__shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_qh = tile_x_qh;
*x_sc = tile_x_sc;
}
static __device__ __forceinline__ void load_tiles_q3_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
const int kbx = k / QI3_K;
const int kqsx = k % QI3_K;
const block_q3_K * bx = ((block_q3_K *) vx) + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
x_dm[i * (WARP_SIZE / QI3_K) + kbx].x = bx->d;
x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->hmask, kqsx / 2);
x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8(bx->scales, kqsx / 4);
}
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
const int kbx = k / QI3_K;
const int kqsx = k % QI3_K;
const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16;
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
const int vh = ~x_qh[i * (WARP_SIZE/2) + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
int u[QR3_K];
float d8[QR3_K];
for (int l = 0; l < QR3_K; ++ l) {
const int y_qs_index = j * (QR3_K*WARP_SIZE) + kbx * (QR3_K*QI3_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
u[l] = y_qs[y_qs_index];
d8[l] = y_ds[y_qs_index / QI8_1].x;
}
return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI2_K) + kbx].x, d8);
}
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl(
const int * v, const int * u, const uint8_t * sc, const uint8_t * m, const half2 & dm4, const float * d8) {
@ -3070,6 +3125,18 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static void ggml_mul_mat_q3_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K, vec_dot_q3_K_q8_1_mul_mat>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
static void ggml_mul_mat_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@ -3598,6 +3665,9 @@ inline void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_Q2_K:
ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
break;
case GGML_TYPE_Q3_K:
ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
break;
case GGML_TYPE_Q6_K:
ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
break;
@ -4370,7 +4440,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
} else {
if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 ||
src0->type == GGML_TYPE_Q2_K || src0->type == GGML_TYPE_Q6_K) {
src0->type == GGML_TYPE_Q2_K || src0->type == GGML_TYPE_Q3_K || src0->type == GGML_TYPE_Q6_K) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
} else {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);