Add q6_k fast matmul kernel
This commit is contained in:
parent
34a4917984
commit
8d816d19d1
1 changed files with 79 additions and 3 deletions
|
@ -428,7 +428,7 @@ void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int i
|
|||
*result = sum;
|
||||
}
|
||||
|
||||
__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
|
||||
__kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
|
||||
|
||||
//to rename it later, just to test now
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
|
@ -466,7 +466,7 @@ __kernel void dequantize_mul_mat_vec_q4_K_fast(__global struct block_q4_K * xx,
|
|||
|
||||
const uint8_t * q1 = x[i].qs + q_offset;
|
||||
const uint8_t * q2 = q1 + 64;
|
||||
const float * y1 = y + i*QK_K + y_offset;
|
||||
const float * y1 = yy + i*QK_K + y_offset;
|
||||
const float * y2 = y1 + 128;
|
||||
|
||||
const float dall = vload_half(0, &x[i].d);
|
||||
|
@ -562,6 +562,82 @@ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int i
|
|||
|
||||
}
|
||||
|
||||
__kernel void dequantize_mul_mat_vec_q6_K_fast(__global struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
|
||||
|
||||
const int row = get_group_id(0);
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
const int ib0 = row*num_blocks_per_row;
|
||||
|
||||
const struct block_q6_K * x = xx + ib0;
|
||||
|
||||
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
||||
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
|
||||
|
||||
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
|
||||
|
||||
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||
const int in = tid - step*im; // 0...15 or 0...7
|
||||
|
||||
#if K_QUANTS_PER_ITERATION == 1
|
||||
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
|
||||
const int is = 0;
|
||||
#else
|
||||
const int l0 = 4 * in; // 0, 4, 8, ..., 28
|
||||
const int is = in / 4;
|
||||
#endif
|
||||
const int ql_offset = 64*im + l0;
|
||||
const int qh_offset = 32*im + l0;
|
||||
const int s_offset = 8*im + is;
|
||||
const int y_offset = 128*im + l0;
|
||||
|
||||
tmp[16 * ix + tid] = 0; // partial sum for thread in warp
|
||||
|
||||
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||
|
||||
const float * y = yy + i * QK_K + y_offset;
|
||||
const uint8_t * ql = x[i].ql + ql_offset;
|
||||
const uint8_t * qh = x[i].qh + qh_offset;
|
||||
const int8_t * s = x[i].scales + s_offset;
|
||||
|
||||
const float d = vload_half(0, &x[i].d);
|
||||
|
||||
#if K_QUANTS_PER_ITERATION == 1
|
||||
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
|
||||
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
|
||||
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
|
||||
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
|
||||
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
|
||||
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
|
||||
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
|
||||
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
|
||||
tmp[16 * ix + tid] += sum;
|
||||
#else
|
||||
float sum = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
|
||||
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
|
||||
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
|
||||
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
tmp[16 * ix + tid] += sum;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
for (int s=16; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[row] = tmp[0];
|
||||
}
|
||||
}
|
||||
|
||||
);
|
||||
|
||||
|
||||
|
@ -1041,7 +1117,7 @@ void ggml_cl_init(void) {
|
|||
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
|
||||
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K_fast", &err), err));
|
||||
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
|
||||
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
|
||||
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K_fast", &err), err));
|
||||
|
||||
// mul kernel
|
||||
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue