Fix q2_k fast kernel

This commit is contained in:
0cc4m 2023-06-20 08:01:40 +02:00
parent 69fd31d18c
commit 069cbe530d

View file

@ -21,6 +21,12 @@
#define CL_DMMV_BLOCK_SIZE 32
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
#define MULTILINE_QUOTE(...) #__VA_ARGS__
static std::string program_source = MULTILINE_QUOTE(
@ -429,17 +435,18 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int block_size = get_local_size(0);
const int row = get_group_id(0);
const int num_blocks_per_row = ncols / 256;
const int ib0 = row*num_blocks_per_row;
const int tid = get_local_id(0)/2; // 0...15
const int ix = get_local_id(0)%2;
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int step = 8/K_QUANTS_PER_ITERATION;
const int il = tid/step; // 0...3
const int ir = tid - step*il;// 0...3
const int n = 2*K_QUANTS_PER_ITERATION;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
@ -448,15 +455,14 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
const struct block_q4_K * x = xx + ib0;
tmp[tid] = 0;
tmp[16 * ix + tid] = 0;
for (int i = ix; i < num_blocks_per_row; i += 2) {
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
@ -472,20 +478,20 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
float4 s = {0.f, 0.f, 0.f, 0.f};
float4 s = (float4)(0.f);
float smin = 0;
for (int l = 0; l < n; ++l) {
s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp[tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
}
// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=block_size/2; s>0; s>>=1) {
for (int s=16; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
@ -805,10 +811,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
exit(1);
}
const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
"-DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
if(err < 0) {
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);