Add q6_k support
This commit is contained in:
parent
b6591b5cb4
commit
42bfa889a6
2 changed files with 169 additions and 6 deletions
|
@ -85,6 +85,20 @@ struct block_q8_0
|
|||
#define A_TYPE block_q8_0
|
||||
)";
|
||||
|
||||
// GLSL preamble for Q6_K quantized tensors.
// Defines QUANT_K (256 values per block), declares the block_q6_K layout and
// selects it as A_TYPE for the shader body appended after this preamble.
// As used by the Q6_K shader bodies, each 6-bit quant is assembled from a
// 4-bit nibble in `ql` plus a 2-bit field in `qh`; `scales` holds one int8
// scale per 16 values and `d` is the per-block scale.
const std::string shader_q6_K_defines = R"(
#define QUANT_K 256

struct block_q6_K
{
    uint8_t ql[QUANT_K/2];
    uint8_t qh[QUANT_K/4];
    int8_t scales[QUANT_K/16];
    float16_t d;
};

#define A_TYPE block_q6_K
)";
|
||||
|
||||
// Dequant functions
|
||||
const std::string shader_f16_dequant_func = R"(
|
||||
#define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]);
|
||||
|
@ -438,6 +452,42 @@ void main() {
|
|||
}
|
||||
)";
|
||||
|
||||
// K-quants
|
||||
// GLSL body of the Q6_K dequantization compute shader.
// One workgroup of 64 invocations handles one block x[i] (i = gl_WorkGroupID.x).
// Each invocation writes four outputs, at offsets 0/32/64/96 from its base
// index, so a workgroup produces 64 * 4 = QUANT_K values.
// Every value is rebuilt as (4 bits of ql | 2 bits of qh << 4) - 32, then
// multiplied by its int8 scale and the block scale d.
// The push-constant members (M, K, stride_a, stride_b) are declared but not
// read by main().
const std::string dequant_q6_K_body = R"(
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) writeonly buffer D { D_TYPE y[]; };

layout (push_constant) uniform parameter
{
    int M;
    int K;
    int stride_a;
    int stride_b;
} p;

void main() {
    const int i = int(gl_WorkGroupID.x);
    const int tid = int(gl_LocalInvocationID.x);
    const int ip = tid / 32;
    const int il = tid - 32 * ip;
    const int is = 8 * ip + il / 16;

    const int y_idx = i * QUANT_K + 128 * ip + il;

    const int ql_idx = 64 * ip + il;
    const uint8_t qh = x[i].qh[32 * ip + il];

    const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);

    y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
    y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
    y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
    y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
)";
|
||||
|
||||
// Mul Mat Vec
|
||||
const std::string mul_mat_vec_head = R"(
|
||||
#version 450
|
||||
|
@ -497,6 +547,90 @@ void main() {
|
|||
}
|
||||
)";
|
||||
|
||||
// GLSL body of the Q6_K matrix-vector multiplication compute shader.
// One workgroup of 32 invocations computes one output row (gl_WorkGroupID.x):
// each invocation accumulates a partial dot product of the dequantized row of
// x with y into shared tmp[], the partials are tree-reduced, and the total is
// written to dst[row].
// K_QUANTS_PER_ITERATION (1 or 2) must be defined before this body; it picks
// how the invocations are split across blocks (ix) and positions in a block.
//
// Fix: the K_QUANTS_PER_ITERATION == 2 branch computed `is` from `in`, which
// is a reserved GLSL keyword; the variable is named v_in everywhere else.
//
// NOTE(review): with K_QUANTS_PER_ITERATION == 2 two invocations share each
// `tid`, and both run the `tmp[tid] += tmp[tid + s]` reduction step on the
// same element - confirm this is safe on the targeted hardware.
const std::string mul_mat_vec_q6_K_body = R"(
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) readonly buffer B { B_TYPE y[]; };
layout (binding = 2) writeonly buffer D { D_TYPE dst[]; };

layout (push_constant) uniform parameter
{
    int ncols;
} p;

shared FLOAT_TYPE tmp[32];

void main() {
    const int row = int(gl_WorkGroupID.x);

    const int num_blocks_per_row = p.ncols / QUANT_K;
    const int ib0 = row*num_blocks_per_row;

    const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
    const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1

    const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8

    const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
    const int v_in = tid - step*v_im; // 0...15 or 0...7

#if K_QUANTS_PER_ITERATION == 1
    const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
    const int is = 0;
#else
    const int l0 = 4 * v_in; // 0, 4, 8, ..., 28
    const int is = v_in / 4;
#endif

    const int ql_offset = 64*v_im + l0;
    const int qh_offset = 32*v_im + l0;
    const int s_offset = 8*v_im + is;
    const int y_offset = 128*v_im + l0;

    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp

    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
        const int y_idx = i * QUANT_K + y_offset;

        const FLOAT_TYPE d = x[ib0 + i].d;

#if K_QUANTS_PER_ITERATION == 1
        FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
                       + FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
                       + FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
                       + FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
                       + FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
                       + FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
                       + FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
                       + FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
        tmp[16 * ix + tid] += sum;
#else
        FLOAT_TYPE sum = 0;
        for (int l = 0; l < 4; ++l) {
            sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
                 + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
                 + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
                 + FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
        }
        tmp[16 * ix + tid] += sum;
#endif
    }

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (int s = 16; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        barrier();
    }
    if (tid == 0) {
        dst[row] = D_TYPE(tmp[0]);
    }
}
)";
|
||||
|
||||
// F16 to F32
|
||||
const std::string f32_to_f16_src = R"(
|
||||
#version 450
|
||||
|
|
|
@ -54,6 +54,12 @@
|
|||
|
||||
#define VK_NUM_TYPES 16
|
||||
|
||||
// Number of quant blocks each group of shader invocations processes per loop
// iteration in the k-quant mul_mat_vec shaders. Defaults to 1; a value
// supplied externally (e.g. on the compiler command line) must be 1 or 2.
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
|
||||
|
||||
typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
struct vk_buffer {
|
||||
|
@ -714,6 +720,9 @@ static inline bool ggml_vk_build_shader_type_defines(std::stringstream& stream,
|
|||
case GGML_TYPE_Q8_0:
|
||||
stream << shader_q8_0_defines << (compat ? shader_q8_0_dequant_func_compat : shader_q8_0_dequant_func);
|
||||
return true;
|
||||
case GGML_TYPE_Q6_K:
|
||||
stream << shader_q6_K_defines;
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
@ -789,9 +798,20 @@ static void ggml_vk_generate_shaders() {
|
|||
continue;
|
||||
}
|
||||
|
||||
stream << dequant_body;
|
||||
int work_group_denom;
|
||||
|
||||
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
|
||||
switch ((ggml_type)i) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
stream << dequant_q6_K_body;
|
||||
work_group_denom = 64 * 4;
|
||||
break;
|
||||
default:
|
||||
stream << dequant_body;
|
||||
work_group_denom = 256 * 32;
|
||||
break;
|
||||
}
|
||||
|
||||
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
// mul mat vec
|
||||
|
@ -808,10 +828,17 @@ static void ggml_vk_generate_shaders() {
|
|||
continue;
|
||||
}
|
||||
|
||||
stream << mul_mat_vec_body;
|
||||
switch ((ggml_type)i) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
stream << mul_mat_vec_q6_K_body;
|
||||
break;
|
||||
default:
|
||||
stream << mul_mat_vec_body;
|
||||
break;
|
||||
}
|
||||
|
||||
vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
// add
|
||||
|
@ -1049,6 +1076,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
|
|||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
|
@ -1068,6 +1096,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
|
|||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
|
@ -1205,7 +1234,7 @@ void ggml_vk_host_free(void* ptr) {
|
|||
}
|
||||
}
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "WARNING: to free pinned memory: memory not in map\n");
|
||||
fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue