From 069cbe530d826b1b19559d1ded5032202032c287 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 20 Jun 2023 08:01:40 +0200 Subject: [PATCH] Fix q2_k fast kernel --- ggml-opencl.cpp | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 8a84b3453..13f603eca 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -21,6 +21,12 @@ #define CL_DMMV_BLOCK_SIZE 32 +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + #define MULTILINE_QUOTE(...) #__VA_ARGS__ static std::string program_source = MULTILINE_QUOTE( @@ -429,17 +435,18 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int block_size = get_local_size(0); const int row = get_group_id(0); const int num_blocks_per_row = ncols / 256; const int ib0 = row*num_blocks_per_row; - const int tid = get_local_id(0)/2; // 0...15 - const int ix = get_local_id(0)%2; + const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15 + const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int step = 8/K_QUANTS_PER_ITERATION; + + const int il = tid/step; // 0...3 + const int ir = tid - step*il;// 0...3 + const int n = 2*K_QUANTS_PER_ITERATION; const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int in = il%2; @@ -448,15 +455,14 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; - uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; const struct block_q4_K * x = xx + ib0; - tmp[tid] = 0; + tmp[16 * ix + tid] = 0; - for (int i = ix; i < num_blocks_per_row; i += 2) { + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const uint8_t * q1 = x[i].qs + q_offset; const uint8_t * q2 = q1 + 64; @@ -472,20 +478,20 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - float4 s = {0.f, 0.f, 0.f, 0.f}; + float4 s = (float4)(0.f); float smin = 0; for (int l = 0; l < n; ++l) { s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; } - tmp[tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; + tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; } // sum up partial sums and write back result barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { + for (int s=16; s>0; s>>=1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -805,10 +811,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co exit(1); } - const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math " - "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1"; + std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math " + "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 " + "-DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION); - err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL); + err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); if(err < 0) { clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);