Add q6_k support

This commit is contained in:
0cc4m 2023-09-29 23:12:34 +02:00
parent b6591b5cb4
commit 42bfa889a6
2 changed files with 169 additions and 6 deletions

View file

@ -85,6 +85,20 @@ struct block_q8_0
#define A_TYPE block_q8_0
)";
const std::string shader_q6_K_defines = R"(
#define QUANT_K 256
struct block_q6_K
{
uint8_t ql[QUANT_K/2];
uint8_t qh[QUANT_K/4];
int8_t scales[QUANT_K/16];
float16_t d;
};
#define A_TYPE block_q6_K
)";
// Dequant functions
const std::string shader_f16_dequant_func = R"(
#define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]);
@ -438,6 +452,42 @@ void main() {
}
)";
// K-quants
const std::string dequant_q6_K_body = R"(
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) writeonly buffer D { D_TYPE y[]; };
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
const int i = int(gl_WorkGroupID.x);
const int tid = int(gl_LocalInvocationID.x);
const int ip = tid / 32;
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;
const int y_idx = i * QUANT_K + 128 * ip + il;
const int ql_idx = 64 * ip + il;
const uint8_t qh = x[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
)";
// Mul Mat Vec
const std::string mul_mat_vec_head = R"(
#version 450
@ -497,6 +547,90 @@ void main() {
}
)";
const std::string mul_mat_vec_q6_K_body = R"(
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) readonly buffer B { B_TYPE y[]; };
layout (binding = 2) writeonly buffer D { D_TYPE dst[]; };
layout (push_constant) uniform parameter
{
int ncols;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int v_in = tid - step*v_im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const int is = 0;
#else
const int l0 = 4 * v_in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*v_im + l0;
const int qh_offset = 32*v_im + l0;
const int s_offset = 8*v_im + is;
const int y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = x[ib0 + i].d;
#if K_QUANTS_PER_ITERATION == 1
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
#else
FLOAT_TYPE sum = 0;
for (int l = 0; l < 4; ++l) {
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
#endif
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[row] = D_TYPE(tmp[0]);
}
}
)";
// F16 to F32
const std::string f32_to_f16_src = R"(
#version 450

View file

@ -54,6 +54,12 @@
#define VK_NUM_TYPES 16
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
struct vk_buffer {
@ -714,6 +720,9 @@ static inline bool ggml_vk_build_shader_type_defines(std::stringstream& stream,
case GGML_TYPE_Q8_0:
stream << shader_q8_0_defines << (compat ? shader_q8_0_dequant_func_compat : shader_q8_0_dequant_func);
return true;
case GGML_TYPE_Q6_K:
stream << shader_q6_K_defines;
return true;
default:
return false;
}
@ -789,9 +798,20 @@ static void ggml_vk_generate_shaders() {
continue;
}
stream << dequant_body;
int work_group_denom;
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
switch ((ggml_type)i) {
case GGML_TYPE_Q6_K:
stream << dequant_q6_K_body;
work_group_denom = 64 * 4;
break;
default:
stream << dequant_body;
work_group_denom = 256 * 32;
break;
}
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
}
// mul mat vec
@ -808,10 +828,17 @@ static void ggml_vk_generate_shaders() {
continue;
}
stream << mul_mat_vec_body;
switch ((ggml_type)i) {
case GGML_TYPE_Q6_K:
stream << mul_mat_vec_q6_K_body;
break;
default:
stream << mul_mat_vec_body;
break;
}
vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
}
// add
@ -1049,6 +1076,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
@ -1068,6 +1096,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
@ -1205,7 +1234,7 @@ void ggml_vk_host_free(void* ptr) {
}
}
if (buf == nullptr) {
fprintf(stderr, "WARNING: to free pinned memory: memory not in map\n");
fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
return;
}