Add q4_k support
This commit is contained in:
parent
a861879256
commit
4a97d2d1ec
2 changed files with 201 additions and 1 deletions
|
@ -665,6 +665,7 @@ static inline bool ggml_vk_build_shader(ggml_type type) {
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
|
@ -960,6 +961,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -982,6 +984,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -116,6 +116,18 @@ struct block_q3_K
|
||||||
|
|
||||||
#define A_TYPE block_q3_K
|
#define A_TYPE block_q3_K
|
||||||
"""
|
"""
|
||||||
|
# GLSL preamble shared by the Q4_K shaders.
#
# Declares the super-block size (QUANT_K = 256 values) and the storage layout
# of one Q4_K block, then selects it as the A-matrix element type via A_TYPE.
# The text is injected verbatim into the generated shader, so it must contain
# nothing but valid GLSL.
shader_q4_K_defines = """
#define QUANT_K 256

struct block_q4_K
{
    // d.x multiplies the per-sub-block scale, d.y multiplies the per-sub-block
    // minimum (read as dall / dmin by the Q4_K shader bodies).
    f16vec2 d;
    // 12 bytes of packed scale/min values, decoded with 6-bit masks and
    // 2-bit high parts by the shader bodies.
    uint8_t scales[3*QUANT_K/64];
    // 128 bytes, two 4-bit quants per byte (low nibble / high nibble).
    uint8_t qs[QUANT_K/2];
};

#define A_TYPE block_q4_K
"""
|
||||||
shader_q6_K_defines = """
|
shader_q6_K_defines = """
|
||||||
#define QUANT_K 256
|
#define QUANT_K 256
|
||||||
|
|
||||||
|
@ -568,6 +580,68 @@ void main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
# GLSL compute shader body that expands Q4_K blocks into D_TYPE values.
#
# Each workgroup has 32 invocations and walks up to 256 consecutive blocks;
# the walk stops as soon as the block index reaches M * K / QUANT_K.
# Within one block, every invocation writes 4 values from the low nibbles and
# 4 values from the high nibbles of the same 4 quant bytes (8 outputs total).
# A_TYPE, D_TYPE, FLOAT_TYPE and QUANT_K are supplied by the surrounding
# defines; the text is emitted verbatim, so it must be valid GLSL only.
dequant_q4_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) writeonly buffer D {D_TYPE y[];};

layout (push_constant) uniform parameter
{
    int M;
    int K;
    int stride_a;
    int stride_b;
} p;

void main() {
    [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
        // Block handled in this iteration; stop once past the last block.
        const int i = int(gl_WorkGroupID.x * 256 + wgy);
        if (i >= p.M * p.K / QUANT_K) {
            return;
        }

        const int tid = int(gl_LocalInvocationID.x);
        const int il = tid / 8;   // 0...3: which 64-value group of the block
        const int ir = tid % 8;   // 0...7: position inside that group
        const int is = 2 * il;    // index of the first of the two sub-block scales
        const int n = 4;          // quant bytes processed per invocation

        const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y);

        const int y_idx = i * QUANT_K + 64 * il + n * ir;
        const int qs_idx = 32*il + n * ir;

        // Scale/min for sub-block `is` (applied to the low nibbles).
        // Entries 0..3 are stored directly in the low 6 bits; entries 4..7
        // combine a 4-bit part with the top 2 bits of an earlier byte.
        uint8_t sc;
        uint8_t m;
        if (is < 4) {
            sc = uint8_t(x[i].scales[is] & 63);
            m  = uint8_t(x[i].scales[is + 4] & 63);
        } else {
            sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4));
            m  = uint8_t((x[i].scales[is + 4] >>  4) | ((x[i].scales[is    ] >> 6) << 4));
        }
        const FLOAT_TYPE d1 = dall * sc;
        const FLOAT_TYPE m1 = dmin * m;

        // Scale/min for sub-block `is + 1` (applied to the high nibbles).
        if (is < 4) {
            sc = uint8_t(x[i].scales[is + 1] & 63);
            m  = uint8_t(x[i].scales[is + 5] & 63);
        } else {
            sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4));
            m  = uint8_t((x[i].scales[is + 5] >>  4) | ((x[i].scales[is + 1] >> 6) << 4));
        }
        const FLOAT_TYPE d2 = dall * sc;
        const FLOAT_TYPE m2 = dmin * m;

        for (int l = 0; l < n; ++l) {
            y[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1);
            y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >>  4) - m2);
        }
    }
}
"""
|
||||||
dequant_q6_K_body = """
|
dequant_q6_K_body = """
|
||||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
@ -814,6 +888,125 @@ void main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
# GLSL compute shader body for a fused dequantize + matrix-vector product
# with a Q4_K matrix.
#
# One workgroup of 32 invocations produces one output element dst[row].
# Each invocation accumulates a partial dot product over the row's blocks
# into a private slot of the shared array `tmp`; the slots are then summed
# with a tree reduction and invocation 0 writes the result.
#
# K_QUANTS_PER_ITERATION (1 or 2) controls how many blocks are processed
# side by side: invocations are split into `ix` lanes that stride over the
# blocks, and `tid` selects the part of a block an invocation covers.
#
# Notes on this version:
#   * Both K_QUANTS_PER_ITERATION paths use the same decoded scales
#     (sc0, sc1, sc4, sc5).  The previous K_QUANTS_PER_ITERATION == 1 path
#     indexed the scale bytes with v_im instead of 2*v_im (and +4/+5 instead
#     of +8/+9), which disagreed with the decoding a few lines above it.
#   * The reduction is keyed on the unique local invocation index (`lid`).
#     `tid` is shared by K_QUANTS_PER_ITERATION invocations, so keying the
#     reduction on it made two invocations update the same slot concurrently.
#
# A_TYPE, B_TYPE, D_TYPE, FLOAT_TYPE, QUANT_K and K_QUANTS_PER_ITERATION are
# supplied by the surrounding defines; the text is emitted verbatim, so it
# must be valid GLSL only.
mul_mat_vec_q4_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) readonly buffer B {B_TYPE y[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};

layout (push_constant) uniform parameter
{
    int ncols;
} p;

shared FLOAT_TYPE tmp[32];

void main() {
    const int row = int(gl_WorkGroupID.x);

    const int num_blocks_per_row = p.ncols / QUANT_K;
    const int ib0 = row*num_blocks_per_row;

    const int lid = int(gl_LocalInvocationID.x);      // 0...31, unique per invocation
    const int tid = lid/K_QUANTS_PER_ITERATION;       // 0...31 or 0...15
    const int ix  = lid%K_QUANTS_PER_ITERATION;       // 0 or 0, 1

    const int step = 8/K_QUANTS_PER_ITERATION;        // 8 or 4

    const int il = tid/step;                          // 0...3
    const int ir = tid - step*il;                     // 0...7 or 0...3
    const int n = 2 * K_QUANTS_PER_ITERATION;         // 2 or 4

    const int v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int v_in = il % 2;

    const int l0 = n * (2 * ir + v_in);               // 0...30 or 0...28
    const int q_offset = 32*v_im + l0;
    const int y_offset = 64*v_im + l0;

    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp

    [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
        const int y1_idx = i * QUANT_K + y_offset;
        const int y2_idx = y1_idx + 128;

        const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y);

        // sc0/sc1 and sc4/sc5 are the scales of the four sub-blocks touched
        // here; sc2/sc3 and sc6/sc7 are the matching minimums.
        const uint8_t sc0 = uint8_t(  x[ib0 + i].scales[v_im * 2    ]       & 0x3f);
        const uint8_t sc1 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
        const uint8_t sc2 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
        const uint8_t sc3 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
        const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((x[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
        const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
        const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
        const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));

#if K_QUANTS_PER_ITERATION == 2
        const uint8_t q4_0  = uint8_t(x[ib0 + i].qs[q_offset     ] & 0xf);
        const uint8_t q4_1  = uint8_t(x[ib0 + i].qs[q_offset +  1] & 0xf);
        const uint8_t q4_2  = uint8_t(x[ib0 + i].qs[q_offset +  2] & 0xf);
        const uint8_t q4_3  = uint8_t(x[ib0 + i].qs[q_offset +  3] & 0xf);
        const uint8_t q4_4  = uint8_t(x[ib0 + i].qs[q_offset     ]  >> 4);
        const uint8_t q4_5  = uint8_t(x[ib0 + i].qs[q_offset +  1]  >> 4);
        const uint8_t q4_6  = uint8_t(x[ib0 + i].qs[q_offset +  2]  >> 4);
        const uint8_t q4_7  = uint8_t(x[ib0 + i].qs[q_offset +  3]  >> 4);
        const uint8_t q4_8  = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
        const uint8_t q4_9  = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
        const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 66] & 0xf);
        const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 67] & 0xf);
        const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64]  >> 4);
        const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65]  >> 4);
        const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 66]  >> 4);
        const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 67]  >> 4);

        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx     ] * q4_0  + y[y1_idx +  1] * q4_1  + y[y1_idx +  2] * q4_2  + y[y1_idx +  3] * q4_3);
        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4  + y[y1_idx + 33] * q4_5  + y[y1_idx + 34] * q4_6  + y[y1_idx + 35] * q4_7);
        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx     ] * q4_8  + y[y2_idx +  1] * q4_9  + y[y2_idx +  2] * q4_10 + y[y2_idx +  3] * q4_11);
        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15);
        const FLOAT_TYPE smin = FLOAT_TYPE(
            y[y1_idx    ] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx    ] * sc6 + y[y2_idx + 32] * sc7
          + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
          + y[y1_idx + 2] * sc2 + y[y1_idx + 34] * sc3 + y[y2_idx + 2] * sc6 + y[y2_idx + 34] * sc7
          + y[y1_idx + 3] * sc2 + y[y1_idx + 35] * sc3 + y[y2_idx + 3] * sc6 + y[y2_idx + 35] * sc7
        );
        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
#else
        const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset     ] & 0xf);
        const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset +  1] & 0xf);
        const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset     ]  >> 4);
        const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset +  1]  >> 4);
        const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
        const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
        const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64]  >> 4);
        const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65]  >> 4);

        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx     ] * q4_0 + y[y1_idx +  1] * q4_1);
        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3);
        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx     ] * q4_4 + y[y2_idx +  1] * q4_5);
        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7);
        const FLOAT_TYPE smin = FLOAT_TYPE(
            y[y1_idx    ] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx    ] * sc6 + y[y2_idx + 32] * sc7
          + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
        );

        // Same scale terms as the two-quant path above.
        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
#endif
    }

    // sum up partial sums and write back result
    // The slots 16 * ix + tid cover 0...31 exactly once, so the reduction
    // runs over all 32 slots using the unique invocation index.
    barrier();
    [[unroll]] for (int s = 16; s > 0; s >>= 1) {
        if (lid < s) {
            tmp[lid] += tmp[lid + s];
        }
        barrier();
    }
    if (lid == 0) {
        dst[row] = D_TYPE(tmp[0]);
    }
}
"""
|
||||||
mul_mat_vec_q6_K_body = """
|
mul_mat_vec_q6_K_body = """
|
||||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
@ -1068,7 +1261,7 @@ type_names = {
|
||||||
GGML_TYPE_Q8_K: "q8_K",
|
GGML_TYPE_Q8_K: "q8_K",
|
||||||
}
|
}
|
||||||
|
|
||||||
K_QUANTS_PER_ITERATION = 1
|
K_QUANTS_PER_ITERATION = 2
|
||||||
|
|
||||||
|
|
||||||
async def string_to_spv_file(name, code, defines, fp16):
|
async def string_to_spv_file(name, code, defines, fp16):
|
||||||
|
@ -1184,6 +1377,8 @@ async def main():
|
||||||
stream.extend((shader_q2_K_defines, dequant_q2_K_body))
|
stream.extend((shader_q2_K_defines, dequant_q2_K_body))
|
||||||
elif i == GGML_TYPE_Q3_K:
|
elif i == GGML_TYPE_Q3_K:
|
||||||
stream.extend((shader_q3_K_defines, dequant_q3_K_body))
|
stream.extend((shader_q3_K_defines, dequant_q3_K_body))
|
||||||
|
elif i == GGML_TYPE_Q4_K:
|
||||||
|
stream.extend((shader_q4_K_defines, dequant_q4_K_body))
|
||||||
elif i == GGML_TYPE_Q6_K:
|
elif i == GGML_TYPE_Q6_K:
|
||||||
stream.extend((shader_q6_K_defines, dequant_q6_K_body))
|
stream.extend((shader_q6_K_defines, dequant_q6_K_body))
|
||||||
else:
|
else:
|
||||||
|
@ -1212,6 +1407,8 @@ async def main():
|
||||||
stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
|
stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
|
||||||
elif i == GGML_TYPE_Q3_K:
|
elif i == GGML_TYPE_Q3_K:
|
||||||
stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
|
stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
|
||||||
|
elif i == GGML_TYPE_Q4_K:
|
||||||
|
stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
|
||||||
elif i == GGML_TYPE_Q6_K:
|
elif i == GGML_TYPE_Q6_K:
|
||||||
stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
|
stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue