diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 477435ebf..225303b54 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -428,7 +428,7 @@ void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int i *result = sum; } -__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* y, __global float* dst, const int ncols) { +__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) { //to rename it later, just to test now const uint16_t kmask1 = 0x3f3f; @@ -466,7 +466,7 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, const uint8_t * q1 = x[i].qs + q_offset; const uint8_t * q2 = q1 + 64; - const float * y1 = y + i*QK_K + y_offset; + const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; const float dall = vload_half(0, &x[i].d); @@ -562,6 +562,82 @@ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int i } +__kernel void dequantize_mul_mat_vec_q6_K_fast(__global struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) { + + const int row = get_group_id(0); + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const struct block_q6_K * x = xx + ib0; + + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + tmp[16 * ix + tid] = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = vload_half(0, &x[i].d); + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp[16 * ix + tid] += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp[16 * ix + tid] += sum; +#endif + + } + + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=16; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} + ); @@ -1041,7 +1117,7 @@ void ggml_cl_init(void) { CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K_fast", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err)); - CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K_fast", &err), err)); // mul kernel CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));