From 1506affd0a96dca301b72230df11fdec6f2ea7e8 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sun, 11 Jun 2023 22:29:43 +0800 Subject: [PATCH] Added q6_k kernel --- ggml-opencl.cpp | 81 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 25 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 44d26d968..81173c495 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -61,28 +61,36 @@ struct __attribute__ ((packed)) block_q8_0 struct __attribute__ ((packed)) block_q2_K { - uchar scales[16]; - uchar qs[64]; - half d; - half dmin; -}; - -struct __attribute__ ((packed)) block_q4_K -{ - uchar scales[12]; - uchar qs[128]; + uint8_t scales[16]; + uint8_t qs[64]; half d; half dmin; }; struct __attribute__ ((packed)) block_q3_K { - uchar hmask[32]; - uchar qs[64]; - uchar scales[12]; + uint8_t hmask[32]; + uint8_t qs[64]; + uint8_t scales[12]; half d; }; +struct __attribute__ ((packed)) block_q4_K +{ + uint8_t scales[12]; + uint8_t qs[128]; + half d; + half dmin; +}; + +struct __attribute__ ((packed)) block_q6_K +{ + uint8_t ql[128]; + uint8_t qh[64]; + int8_t scales[16]; + half d; +}; + __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) { const uint i = get_global_id(0); @@ -172,7 +180,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K* x, __globa const int l = tid - 32 * n; const int is = 8 * n + l / 16; - const uchar q = x[i].qs[32 * n + l]; + const uint8_t q = x[i].qs[32 * n + l]; __global float *y = yy + i * 256 + 128 * n; const float dall = vload_half(0, &x[i].d); @@ -193,11 +201,11 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K* x, __globa int n = tid / 4; int j = tid - 4 * n; - uchar m = 1 << (4 * n + j); + uint8_t m = 1 << (4 * n + j); int is = 8 * n + 2 * j + is0; int shift = 2 * j; - uchar us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4) : + int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4) : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4) : is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4) : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4); @@ -205,11 +213,12 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K* x, __globa float dl = d_all * (us - 32); __global float *y = yy + i * 256 + 128 * n + 32 * j; - const __global uchar *q = x[i].qs + 32 * n; - const __global uchar *hm = x[i].hmask; + const __global uint8_t *q = x[i].qs + 32 * n; + const __global uint8_t *hm = x[i].hmask; for (int l = l0; l < l0 + 4; ++l) - y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l / 8] & m) ? 0 : 4)); + y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + } __kernel void dequantize_block_q4_K(__global const struct block_q4_K* x, __global float *yy) { @@ -226,9 +235,9 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K* x, __globa const float dall = vload_half(0, &x[i].d); const float dmin = vload_half(0, &x[i].dmin); - __global const uchar *q = x[i].qs + 32 * il + n * ir; + __global const uint8_t *q = x[i].qs + 32 * il + n * ir; - uchar sc, m; + uint8_t sc, m; get_scale_min_k4(is + 0, x[i].scales, &sc, &m); float d1 = dall * sc; float m1 = dmin * m; @@ -241,6 +250,28 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K* x, __globa } } +__kernel void dequantize_block_q6_K(__global const struct block_q6_K* x, __global float *yy) { + + const int i = get_group_id(0); + const int tid = get_local_id(0); + const int ip = tid / 32; + const int il = tid - 32 * ip; + const int is = 8 * ip + il / 16; + + __global float* y = yy + i * 256 + 128 * ip + il; + + const float d = vload_half(0, &x[i].d); + + __global const uint8_t * ql = x[i].ql + 64 * ip + il; + const uint8_t qh = x[i].qh[32 * ip + il]; + __global const int8_t * sc = x[i].scales + is; + + y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + ); // __kernel void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, __global float *result) { @@ -474,7 +505,7 @@ static cl_program program; static cl_kernel convert_row_f16_cl; static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl; static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl; -static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl; +static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q6_k_cl; static cl_kernel mul_f32_cl; static bool fp16_support; @@ -708,6 +739,7 @@ void ggml_cl_init(void) { CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err)); CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err)); CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err)); + CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err)); // dequant mul mat kernel CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err)); @@ -740,6 +772,8 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) { return &dequantize_block_q3_k_cl; case GGML_TYPE_Q4_K: return &dequantize_block_q4_k_cl; + case GGML_TYPE_Q6_K: + return &dequantize_block_q6_k_cl; case GGML_TYPE_F16: return &convert_row_f16_cl; default: @@ -1257,9 +1291,6 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * d_Q = ggml_cl_pool_malloc(q_sz, &q_size); } - printf("\ntype:%d q_sz:%d y_sz:%d ne00:%d ne01:%d ne10:%d ne11:%d nb2:%d nb3:%d",type,q_size,y_size,ne00,ne01,ne10,ne11); - fflush(stdout); - cl_kernel* to_fp32_cl = ggml_get_to_fp32_cl(type); cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type); GGML_ASSERT(to_fp32_cl != nullptr);