k-quants : allow to optionally disable at compile time (#1734)
* k-quants : put behind optional compile flag LLAMA_K_QUANTS * build : enable k-quants by default
This commit is contained in:
parent
5b57a5b726
commit
5c64a0952e
6 changed files with 250 additions and 228 deletions
120
ggml-cuda.cu
120
ggml-cuda.cu
|
@ -110,24 +110,24 @@ typedef struct {
|
|||
uint8_t qs[QK_K/4]; // quants
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
} block_q2_k;
|
||||
static_assert(sizeof(block_q2_k) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_k block size/padding");
|
||||
} block_q2_K;
|
||||
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
uint8_t hmask[QK_K/8];
|
||||
uint8_t qs[QK_K/4]; // nibbles / quants
|
||||
uint8_t scales[3*QK_K/64];
|
||||
half d;
|
||||
} block_q3_k;
|
||||
static_assert(sizeof(block_q3_k) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_k block size/padding");
|
||||
} block_q3_K;
|
||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_k;
|
||||
static_assert(sizeof(block_q4_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_k block size/padding");
|
||||
} block_q4_K;
|
||||
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
|
@ -135,16 +135,16 @@ typedef struct {
|
|||
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
||||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_k;
|
||||
static_assert(sizeof(block_q5_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_k block size/padding");
|
||||
} block_q5_K;
|
||||
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
|
||||
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales
|
||||
half d; // delta
|
||||
} block_q6_k;
|
||||
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
|
||||
} block_q6_K;
|
||||
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
|
@ -299,7 +299,7 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
|
|||
|
||||
//================================== k-quants
|
||||
|
||||
static __global__ void dequantize_block_q2_k(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -307,7 +307,7 @@ static __global__ void dequantize_block_q2_k(const void * vx, float * yy) {
|
|||
const int l = tid - 32*n;
|
||||
const int is = 8*n + l/16;
|
||||
|
||||
const block_q2_k * x = (const block_q2_k *) vx;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
const uint8_t q = x[i].qs[32*n + l];
|
||||
float * y = yy + i*QK_K + 128*n;
|
||||
|
@ -321,9 +321,9 @@ static __global__ void dequantize_block_q2_k(const void * vx, float * yy) {
|
|||
|
||||
}
|
||||
|
||||
static __device__ void vec_dot_q2_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
|
||||
const block_q2_k * x = (const block_q2_k *) vx;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
// if n is 0, we want to do the lower 128, else the upper 128,
|
||||
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
|
||||
|
@ -352,7 +352,7 @@ static __device__ void vec_dot_q2_k(const void * vx, const int ib, const int iqs
|
|||
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q3_k(const void * vx, float * yy) {
|
||||
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
||||
|
||||
int r = threadIdx.x/4;
|
||||
int i = blockIdx.x;
|
||||
|
@ -362,7 +362,7 @@ static __global__ void dequantize_block_q3_k(const void * vx, float * yy) {
|
|||
int n = tid / 4;
|
||||
int j = tid - 4*n;
|
||||
|
||||
const block_q3_k * x = (const block_q3_k *) vx;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
|
||||
uint8_t m = 1 << (4*n + j);
|
||||
int is = 8*n + 2*j + is0;
|
||||
|
@ -383,9 +383,9 @@ static __global__ void dequantize_block_q3_k(const void * vx, float * yy) {
|
|||
|
||||
}
|
||||
|
||||
static __device__ void vec_dot_q3_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
|
||||
const block_q3_k * x = (const block_q3_k *) vx;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
|
||||
const uint32_t kmask1 = 0x03030303;
|
||||
const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
|
@ -437,8 +437,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
|
|||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q4_k(const void * vx, float * yy) {
|
||||
const block_q4_k * x = (const block_q4_k *) vx;
|
||||
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
|
@ -474,9 +474,9 @@ static __global__ void dequantize_block_q4_k(const void * vx, float * yy) {
|
|||
}
|
||||
}
|
||||
|
||||
static __device__ void vec_dot_q4_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
|
||||
const block_q4_k * x = (const block_q4_k *) vx;
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
// iqs is in 0...248 in steps of 8 =>
|
||||
const int j = iqs / 64; // j is in 0...3
|
||||
|
@ -506,8 +506,8 @@ static __device__ void vec_dot_q4_k(const void * vx, const int ib, const int iqs
|
|||
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q5_k(const void * vx, float * yy) {
|
||||
const block_q5_k * x = (const block_q5_k *) vx;
|
||||
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
|
@ -539,9 +539,9 @@ static __global__ void dequantize_block_q5_k(const void * vx, float * yy) {
|
|||
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
||||
}
|
||||
|
||||
static __device__ void vec_dot_q5_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
|
||||
const block_q5_k * x = (const block_q5_k *) vx;
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
// iqs is in 0...248 in steps of 8 =>
|
||||
const int j = iqs / 64; // j is in 0...3
|
||||
|
@ -576,8 +576,8 @@ static __device__ void vec_dot_q5_k(const void * vx, const int ib, const int iqs
|
|||
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q6_k(const void * vx, float * yy) {
|
||||
const block_q6_k * x = (const block_q6_k *) vx;
|
||||
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
|
||||
|
@ -601,9 +601,9 @@ static __global__ void dequantize_block_q6_k(const void * vx, float * yy) {
|
|||
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
|
||||
static __device__ void vec_dot_q6_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
|
||||
|
||||
const block_q6_k * x = (const block_q6_k *) vx;
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int ip = iqs / 128; // 0 or 1
|
||||
const int il = (iqs - 128*ip)/8; // 0...15
|
||||
|
@ -804,29 +804,29 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
|
|||
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
static void dequantize_row_q2_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q2_k<<<nb, 64, 0, stream>>>(vx, y);
|
||||
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_row_q3_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q3_k<<<nb, 64, 0, stream>>>(vx, y);
|
||||
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q4_k<<<nb, 32, 0, stream>>>(vx, y);
|
||||
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_row_q5_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q5_k<<<nb, 64, 0, stream>>>(vx, y);
|
||||
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_row_q6_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_q6_k<<<nb, 64, 0, stream>>>(vx, y);
|
||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
|
@ -869,35 +869,35 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
|
|||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q2_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2;
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q2_k><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q3_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const dim3 block_dims(32, 2, 1);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q3_k><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const dim3 block_dims(32, 2, 1);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q4_k><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const dim3 block_dims(32, 2, 1);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q5_k><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q6_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const dim3 block_dims(32, 2, 1);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q6_k><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
|
@ -926,15 +926,15 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|||
case GGML_TYPE_Q8_0:
|
||||
return dequantize_row_q8_0_cuda;
|
||||
case GGML_TYPE_Q2_K:
|
||||
return dequantize_row_q2_k_cuda;
|
||||
return dequantize_row_q2_K_cuda;
|
||||
case GGML_TYPE_Q3_K:
|
||||
return dequantize_row_q3_k_cuda;
|
||||
return dequantize_row_q3_K_cuda;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return dequantize_row_q4_k_cuda;
|
||||
return dequantize_row_q4_K_cuda;
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_k_cuda;
|
||||
return dequantize_row_q5_K_cuda;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_k_cuda;
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_fp16_to_fp32_cuda;
|
||||
default:
|
||||
|
@ -1277,19 +1277,19 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
|||
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
dequantize_mul_mat_vec_q2_k_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
dequantize_mul_mat_vec_q3_k_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
dequantize_mul_mat_vec_q4_k_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
dequantize_mul_mat_vec_q5_k_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
dequantize_mul_mat_vec_q6_k_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue