Fix q2_k fast kernel
This commit is contained in:
parent
69fd31d18c
commit
069cbe530d
1 changed files with 22 additions and 15 deletions
|
@ -21,6 +21,12 @@
|
||||||
|
|
||||||
#define CL_DMMV_BLOCK_SIZE 32
|
#define CL_DMMV_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
#ifndef K_QUANTS_PER_ITERATION
|
||||||
|
#define K_QUANTS_PER_ITERATION 2
|
||||||
|
#else
|
||||||
|
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||||
static std::string program_source = MULTILINE_QUOTE(
|
static std::string program_source = MULTILINE_QUOTE(
|
||||||
|
|
||||||
|
@ -429,17 +435,18 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
const uint16_t kmask3 = 0xc0c0;
|
const uint16_t kmask3 = 0xc0c0;
|
||||||
|
|
||||||
const int block_size = get_local_size(0);
|
|
||||||
const int row = get_group_id(0);
|
const int row = get_group_id(0);
|
||||||
const int num_blocks_per_row = ncols / 256;
|
const int num_blocks_per_row = ncols / 256;
|
||||||
const int ib0 = row*num_blocks_per_row;
|
const int ib0 = row*num_blocks_per_row;
|
||||||
|
|
||||||
const int tid = get_local_id(0)/2; // 0...15
|
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
|
||||||
const int ix = get_local_id(0)%2;
|
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
|
||||||
|
|
||||||
const int il = tid/4; // 0...3
|
const int step = 8/K_QUANTS_PER_ITERATION;
|
||||||
const int ir = tid - 4*il;// 0...3
|
|
||||||
const int n = 4;
|
const int il = tid/step; // 0...3
|
||||||
|
const int ir = tid - step*il;// 0...3
|
||||||
|
const int n = 2*K_QUANTS_PER_ITERATION;
|
||||||
|
|
||||||
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
||||||
const int in = il%2;
|
const int in = il%2;
|
||||||
|
@ -448,15 +455,14 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
|
||||||
const int q_offset = 32*im + l0;
|
const int q_offset = 32*im + l0;
|
||||||
const int y_offset = 64*im + l0;
|
const int y_offset = 64*im + l0;
|
||||||
|
|
||||||
|
|
||||||
uint16_t aux[4];
|
uint16_t aux[4];
|
||||||
const uint8_t * sc = (const uint8_t *)aux;
|
const uint8_t * sc = (const uint8_t *)aux;
|
||||||
|
|
||||||
const struct block_q4_K * x = xx + ib0;
|
const struct block_q4_K * x = xx + ib0;
|
||||||
|
|
||||||
tmp[tid] = 0;
|
tmp[16 * ix + tid] = 0;
|
||||||
|
|
||||||
for (int i = ix; i < num_blocks_per_row; i += 2) {
|
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||||
|
|
||||||
const uint8_t * q1 = x[i].qs + q_offset;
|
const uint8_t * q1 = x[i].qs + q_offset;
|
||||||
const uint8_t * q2 = q1 + 64;
|
const uint8_t * q2 = q1 + 64;
|
||||||
|
@ -472,20 +478,20 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
|
||||||
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
|
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
|
||||||
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
|
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
|
||||||
|
|
||||||
float4 s = {0.f, 0.f, 0.f, 0.f};
|
float4 s = (float4)(0.f);
|
||||||
float smin = 0;
|
float smin = 0;
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
|
s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
|
||||||
s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
|
s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
|
||||||
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
|
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
|
||||||
}
|
}
|
||||||
tmp[tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
|
tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
for (int s=block_size/2; s>0; s>>=1) {
|
for (int s=16; s>0; s>>=1) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
tmp[tid] += tmp[tid + s];
|
tmp[tid] += tmp[tid + s];
|
||||||
}
|
}
|
||||||
|
@ -805,10 +811,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
|
std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
|
||||||
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
|
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
|
||||||
|
"-DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
|
||||||
|
|
||||||
err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
|
err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
|
||||||
if(err < 0) {
|
if(err < 0) {
|
||||||
|
|
||||||
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
|
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue