q2_K
This commit is contained in:
parent
4b3af63ee8
commit
5bff3df032
1 changed files with 79 additions and 18 deletions
97
ggml-cuda.cu
97
ggml-cuda.cu
|
@ -182,8 +182,7 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||||
uint8_t qs[QK_K/4]; // quants
|
uint8_t qs[QK_K/4]; // quants
|
||||||
half d; // super-block scale for quantized scales
|
half2 dm; // super-block scale for quantized scales/mins
|
||||||
half dmin; // super-block scale for quantized mins
|
|
||||||
} block_q2_K;
|
} block_q2_K;
|
||||||
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
||||||
|
|
||||||
|
@ -505,8 +504,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
|
||||||
const uint8_t q = x[i].qs[32*n + l];
|
const uint8_t q = x[i].qs[32*n + l];
|
||||||
float * y = yy + i*QK_K + 128*n;
|
float * y = yy + i*QK_K + 128*n;
|
||||||
|
|
||||||
float dall = x[i].d;
|
float dall = x[i].dm.x;
|
||||||
float dmin = x[i].dmin;
|
float dmin = x[i].dm.y;
|
||||||
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||||
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
||||||
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
||||||
|
@ -516,8 +515,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
|
||||||
const int il = tid%16; // 0...15
|
const int il = tid%16; // 0...15
|
||||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
float * y = yy + i*QK_K + 16*is + il;
|
float * y = yy + i*QK_K + 16*is + il;
|
||||||
float dall = x[i].d;
|
float dall = x[i].dm.x;
|
||||||
float dmin = x[i].dmin;
|
float dmin = x[i].dm.y;
|
||||||
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||||
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
|
||||||
#endif
|
#endif
|
||||||
|
@ -755,8 +754,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||||
const float * y = yy + i * QK_K + y_offset;
|
const float * y = yy + i * QK_K + y_offset;
|
||||||
const uint8_t * q = x[i].qs + q_offset;
|
const uint8_t * q = x[i].qs + q_offset;
|
||||||
|
|
||||||
const float dall = x[i].d;
|
const float dall = x[i].dm.x;
|
||||||
const float dmin = x[i].dmin;
|
const float dmin = x[i].dm.y;
|
||||||
|
|
||||||
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
|
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
|
||||||
aux[0] = a[0] & 0x0f0f0f0f;
|
aux[0] = a[0] & 0x0f0f0f0f;
|
||||||
|
@ -798,9 +797,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||||
uaux[0] = s[0] & 0x0f0f0f0f;
|
uaux[0] = s[0] & 0x0f0f0f0f;
|
||||||
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
|
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
|
||||||
|
|
||||||
const half2 * dh = (const half2 *)&x[i].d;
|
const float2 dall = __half22float2(x[i].dm);
|
||||||
|
|
||||||
const float2 dall = __half22float2(dh[0]);
|
|
||||||
|
|
||||||
float sum1 = 0, sum2 = 0;
|
float sum1 = 0, sum2 = 0;
|
||||||
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
||||||
|
@ -1709,7 +1706,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
|
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
|
||||||
const int & v, const int * u, const uint8_t * scales, const float & d, const float & dmin, const float * d8) {
|
const int & v, const int * u, const uint8_t * scales, const half2 & dm, const float * d8) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
float sumf_d = 0.0f;
|
float sumf_d = 0.0f;
|
||||||
|
@ -1724,7 +1721,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
|
||||||
sumf_m += d8[i] * (__dp4a(0x01010101, u[i], 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values
|
sumf_m += d8[i] * (__dp4a(0x01010101, u[i], 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values
|
||||||
}
|
}
|
||||||
|
|
||||||
return d*sumf_d - dmin*sumf_m;
|
const float2 dmf = __half22float2(dm);
|
||||||
|
|
||||||
|
return dmf.x*sumf_d - dmf.y*sumf_m;
|
||||||
#else
|
#else
|
||||||
return 0.0f; // only to satisfy the compiler
|
return 0.0f; // only to satisfy the compiler
|
||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
@ -1738,9 +1737,6 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
|
||||||
const int bq8_offset = QR2_K * (iqs / QI8_1);
|
const int bq8_offset = QR2_K * (iqs / QI8_1);
|
||||||
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
|
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
|
||||||
|
|
||||||
const float d = bq2_K->d;
|
|
||||||
const float dmin = bq2_K->dmin;
|
|
||||||
|
|
||||||
const uint8_t * scales = bq2_K->scales + scale_offset;
|
const uint8_t * scales = bq2_K->scales + scale_offset;
|
||||||
|
|
||||||
const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
|
const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
|
||||||
|
@ -1752,7 +1748,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
|
||||||
d8[i] = bq8_1[bq8_offset + i].ds.x;
|
d8[i] = bq8_1[bq8_offset + i].ds.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q2_K_q8_1_impl(v, u, scales, d, dmin, d8);
|
return vec_dot_q2_K_q8_1_impl(v, u, scales, bq2_K->dm, d8);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
|
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||||
|
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)];
|
||||||
|
__shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)];
|
||||||
|
|
||||||
|
*x_ql = tile_x_ql;
|
||||||
|
*x_dm = tile_x_dm;
|
||||||
|
*x_sc = tile_x_sc;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void load_tiles_q2_K(
|
||||||
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
|
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
const int kbx = k / QI2_K;
|
||||||
|
const int kqsx = k % QI2_K;
|
||||||
|
|
||||||
|
const block_q2_K * bx = ((block_q2_K *) vx) + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
|
||||||
|
x_dm[i * (WARP_SIZE / QI2_K) + kbx] = bx->dm;
|
||||||
|
x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->scales, kqsx / 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
||||||
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
const int kbx = k / QI2_K;
|
||||||
|
const int kqsx = k % QI2_K;
|
||||||
|
|
||||||
|
const int bq8_offset = QR2_K * (kqsx / QI8_1);
|
||||||
|
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
|
||||||
|
|
||||||
|
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16 + scale_offset;
|
||||||
|
|
||||||
|
int u[QR2_K];
|
||||||
|
float d8[QR2_K];
|
||||||
|
|
||||||
|
for (int l = 0; l < QR2_K; ++ l) {
|
||||||
|
const int y_qs_index = j * (QR2_K*WARP_SIZE) + kbx * (QR2_K*QI2_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
|
||||||
|
u[l] = y_qs[y_qs_index];
|
||||||
|
d8[l] = y_ds[y_qs_index / QI8_1].x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + kbx], d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl(
|
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl(
|
||||||
|
@ -3013,6 +3058,18 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||||
|
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||||
|
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||||
|
|
||||||
|
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||||
|
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
|
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||||
|
mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K, vec_dot_q2_K_q8_1_mul_mat>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_mul_mat_q6_K_q8_1_cuda(
|
static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||||
|
@ -3538,6 +3595,9 @@ inline void ggml_cuda_op_mul_mat_q(
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
|
ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_Q2_K:
|
||||||
|
ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
|
||||||
|
break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
|
ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
|
@ -4309,7 +4369,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
|
||||||
} else {
|
} else {
|
||||||
if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
|
if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
|
||||||
src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_Q6_K) {
|
src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 ||
|
||||||
|
src0->type == GGML_TYPE_Q2_K || src0->type == GGML_TYPE_Q6_K) {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue