Improve dequant shaders, add fast q4_0 dequant

This commit is contained in:
0cc4m 2024-02-09 20:05:47 +01:00
parent 4b7b38bef5
commit b4172ca29f
3 changed files with 6394 additions and 6860 deletions

File diff suppressed because it is too large Load diff

View file

@ -941,18 +941,18 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
// dequant shaders // dequant shaders
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), { 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 5 * sizeof(uint32_t), { 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), { 8 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
// get_rows // get_rows
ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -2292,15 +2292,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
// Allocate descriptor sets // Allocate descriptor sets
ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13); ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
if (qx_needs_dequant) { if (qx_needs_dequant) {
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13); ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
} }
if (qy_needs_dequant) { if (qy_needs_dequant) {
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13); ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, 1);
} }
if (split_k > 1) { if (split_k > 1) {
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13); ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, 1);
} }
if (x_non_contig) { if (x_non_contig) {
@ -2313,9 +2313,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
} }
if (qx_needs_dequant) { if (qx_needs_dequant) {
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 }; const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0) / 32) }; // TODO: replace with subgroup size
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
} }
} }
if (y_non_contig) { if (y_non_contig) {
@ -2505,9 +2505,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
const uint64_t d_shader_offset = d_offset - d_buffer_offset; const uint64_t d_shader_offset = d_offset - d_buffer_offset;
if (!y_non_contig && qy_needs_dequant) { if (!y_non_contig && qy_needs_dequant) {
const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 }; const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1}); ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
} }
// compute // compute
@ -3820,7 +3820,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue); vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
ggml_vk_ctx_begin(ctx, subctx); ggml_vk_ctx_begin(ctx, subctx);
const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne }; const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)(ne / 32) };
ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
ggml_vk_ctx_end(subctx); ggml_vk_ctx_end(subctx);

View file

@ -438,6 +438,15 @@ dequant_head = """#version 450
#extension GL_EXT_control_flow_attributes : require #extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint M;
uint K;
uint stride_a;
uint stride_b;
uint num_groups;
} p;
""" """
dequant_body = """ dequant_body = """
@ -446,33 +455,25 @@ layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
const int i = int(gl_GlobalInvocationID.x); const uint i = gl_GlobalInvocationID.x;
// Transposed // Transposed
const int row = i % (p.K / QUANT_K); const uint row = i % (p.K / QUANT_K);
const int col = i / (p.K / QUANT_K); const uint col = i / (p.K / QUANT_K);
if (row * QUANT_K >= p.K || col >= p.M) { if (row * QUANT_K >= p.K || col >= p.M) {
return; return;
} }
const int stride_a = p.stride_a / QUANT_K; const uint stride_a = p.stride_a / QUANT_K;
const int ib = col * stride_a + row; const uint ib = col * stride_a + row;
const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
const int step = QUANT_R == 1 ? 2 : 1; const uint step = QUANT_R == 1 ? 2 : 1;
[[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) { [[unroll]] for (uint iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
DEQUANT_FUNC DEQUANT_FUNC
data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x); data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
@ -481,6 +482,38 @@ void main() {
} }
""" """
dequant_q4_0_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x;
// assume 32 threads
const uint tid = gl_LocalInvocationID.x;
const uint il = tid/8;
const uint ir = tid%8;
const uint ib = 8*i + ir;
if (ib >= p.num_groups) {
return;
}
const uint b_idx = 256*i + 32*ir + 4*il;
const float d = float(data_a[ib].d);
const float dm = -8.0f * d;
const uint q_idx = 4*il;
[[unroll]] for (uint l = 0; l < 4; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm);
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm);
}
}
"""
# K-quants # K-quants
dequant_q2_K_body = """ dequant_q2_K_body = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
@ -488,29 +521,21 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy); const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) { if (i >= p.M * p.K / QUANT_K) {
return; return;
} }
const int tid = int(gl_LocalInvocationID.x); const uint tid = gl_LocalInvocationID.x;
const int ip = tid / 32; const uint ip = tid / 32;
const int il = tid - 32 * ip; const uint il = tid - 32 * ip;
const int is = 8 * ip + il / 16; const uint is = 8 * ip + il / 16;
const int y_idx = i * QUANT_K + 128 * ip + il; const uint y_idx = i * QUANT_K + 128 * ip + il;
const int ql_idx = 32 * ip + il; const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il]; const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
@ -528,31 +553,23 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy); const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) { if (i >= p.M * p.K / QUANT_K) {
return; return;
} }
const int r = int(gl_LocalInvocationID.x) / 4; const uint r = gl_LocalInvocationID.x / 4;
const int tid = r / 2; const uint tid = r / 2;
const int is0 = r % 2; const uint is0 = r % 2;
const int l0 = 16 * is0 + 4 * (int(gl_LocalInvocationID.x) % 4); const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
const int n = tid / 4; const uint n = tid / 4;
const int j = tid - 4*n; const uint j = tid - 4*n;
const uint8_t m = uint8_t(1 << (4*n + j)); const uint8_t m = uint8_t(1 << (4*n + j));
const int is = 8*n + 2*j + is0; const uint is = 8*n + 2*j + is0;
const int shift = 2*j; const uint shift = 2*j;
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
@ -561,10 +578,10 @@ void main() {
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
const int y_idx = i * QUANT_K + 128 * n + 32 * j; const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
const int qs_idx = 32*n; const uint qs_idx = 32*n;
for (int l = l0; l < l0 + 4; ++l) { for (uint l = l0; l < l0 + 4; ++l) {
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
} }
} }
@ -576,32 +593,24 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy); const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) { if (i >= p.M * p.K / QUANT_K) {
return; return;
} }
const int tid = int(gl_LocalInvocationID.x); const uint tid = gl_LocalInvocationID.x;
const int il = tid / 8; const uint il = tid / 8;
const int ir = tid % 8; const uint ir = tid % 8;
const int is = 2 * il; const uint is = 2 * il;
const int n = 4; const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
const int y_idx = i * QUANT_K + 64 * il + n * ir; const uint y_idx = i * QUANT_K + 64 * il + n * ir;
const int qs_idx = 32*il + n * ir; const uint qs_idx = 32*il + n * ir;
uint8_t sc; uint8_t sc;
uint8_t m; uint8_t m;
@ -625,7 +634,7 @@ void main() {
const FLOAT_TYPE d2 = dall * sc; const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * m; const FLOAT_TYPE m2 = dmin * m;
[[unroll]] for (int l = 0; l < n; ++l) { [[unroll]] for (uint l = 0; l < n; ++l) {
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1); data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2); data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
} }
@ -638,32 +647,24 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy); const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) { if (i >= p.M * p.K / QUANT_K) {
return; return;
} }
const int tid = int(gl_LocalInvocationID.x); const uint tid = gl_LocalInvocationID.x;
const int il = tid / 16; const uint il = tid / 16;
const int ir = tid % 16; const uint ir = tid % 16;
const int is = 2 * il; const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
const int y_idx = i * QUANT_K + 64 * il + 2 * ir; const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
const int qs_idx = 32*il + 2 * ir; const uint qs_idx = 32*il + 2 * ir;
const int qh_idx = 2 * ir; const uint qh_idx = 2 * ir;
uint8_t sc; uint8_t sc;
uint8_t m; uint8_t m;
@ -702,28 +703,20 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() { void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy); const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.M * p.K / QUANT_K) { if (i >= p.M * p.K / QUANT_K) {
return; return;
} }
const int tid = int(gl_LocalInvocationID.x); const uint tid = gl_LocalInvocationID.x;
const int ip = tid / 32; const uint ip = tid / 32;
const int il = tid - 32 * ip; const uint il = tid - 32 * ip;
const int is = 8 * ip + il / 16; const uint is = 8 * ip + il / 16;
const int y_idx = i * QUANT_K + 128 * ip + il; const uint y_idx = i * QUANT_K + 128 * ip + il;
const int ql_idx = 64 * ip + il; const uint ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il]; const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
@ -2208,7 +2201,7 @@ async def main():
if i == GGML_TYPE_F16: if i == GGML_TYPE_F16:
stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body)) stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
elif i == GGML_TYPE_Q4_0: elif i == GGML_TYPE_Q4_0:
stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body)) stream.extend((shader_q4_0_defines, dequant_q4_0_body))
elif i == GGML_TYPE_Q4_1: elif i == GGML_TYPE_Q4_1:
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body)) stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
elif i == GGML_TYPE_Q5_0: elif i == GGML_TYPE_Q5_0: