Improve dequant shaders, add fast q4_0 dequant
This commit is contained in:
parent
4b7b38bef5
commit
b4172ca29f
3 changed files with 6394 additions and 6860 deletions
13009
ggml-vulkan-shaders.hpp
13009
ggml-vulkan-shaders.hpp
File diff suppressed because it is too large
Load diff
|
@ -941,18 +941,18 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
// dequant shaders
|
// dequant shaders
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), { 64, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 5 * sizeof(uint32_t), { 64, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), { 8 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||||
|
|
||||||
// get_rows
|
// get_rows
|
||||||
ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
@ -2292,15 +2292,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
|
||||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||||
|
|
||||||
// Allocate descriptor sets
|
// Allocate descriptor sets
|
||||||
ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13);
|
ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13);
|
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
|
||||||
}
|
}
|
||||||
if (qy_needs_dequant) {
|
if (qy_needs_dequant) {
|
||||||
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
|
ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, 1);
|
||||||
}
|
}
|
||||||
if (split_k > 1) {
|
if (split_k > 1) {
|
||||||
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13);
|
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (x_non_contig) {
|
if (x_non_contig) {
|
||||||
|
@ -2313,9 +2313,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
|
||||||
}
|
}
|
||||||
|
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
|
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0) / 32) }; // TODO: replace with subgroup size
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (y_non_contig) {
|
if (y_non_contig) {
|
||||||
|
@ -2505,9 +2505,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
|
||||||
const uint64_t d_shader_offset = d_offset - d_buffer_offset;
|
const uint64_t d_shader_offset = d_offset - d_buffer_offset;
|
||||||
|
|
||||||
if (!y_non_contig && qy_needs_dequant) {
|
if (!y_non_contig && qy_needs_dequant) {
|
||||||
const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
|
const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
|
ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
|
@ -3820,7 +3820,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
||||||
|
|
||||||
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
|
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
|
||||||
ggml_vk_ctx_begin(ctx, subctx);
|
ggml_vk_ctx_begin(ctx, subctx);
|
||||||
const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne };
|
const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)(ne / 32) };
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
|
ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
|
||||||
ggml_vk_ctx_end(subctx);
|
ggml_vk_ctx_end(subctx);
|
||||||
|
|
||||||
|
|
|
@ -438,6 +438,15 @@ dequant_head = """#version 450
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : require
|
#extension GL_EXT_control_flow_attributes : require
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter
|
||||||
|
{
|
||||||
|
uint M;
|
||||||
|
uint K;
|
||||||
|
uint stride_a;
|
||||||
|
uint stride_b;
|
||||||
|
uint num_groups;
|
||||||
|
} p;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dequant_body = """
|
dequant_body = """
|
||||||
|
@ -446,33 +455,25 @@ layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const int i = int(gl_GlobalInvocationID.x);
|
const uint i = gl_GlobalInvocationID.x;
|
||||||
|
|
||||||
// Transposed
|
// Transposed
|
||||||
const int row = i % (p.K / QUANT_K);
|
const uint row = i % (p.K / QUANT_K);
|
||||||
const int col = i / (p.K / QUANT_K);
|
const uint col = i / (p.K / QUANT_K);
|
||||||
|
|
||||||
if (row * QUANT_K >= p.K || col >= p.M) {
|
if (row * QUANT_K >= p.K || col >= p.M) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int stride_a = p.stride_a / QUANT_K;
|
const uint stride_a = p.stride_a / QUANT_K;
|
||||||
|
|
||||||
const int ib = col * stride_a + row;
|
const uint ib = col * stride_a + row;
|
||||||
|
|
||||||
const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||||
const int step = QUANT_R == 1 ? 2 : 1;
|
const uint step = QUANT_R == 1 ? 2 : 1;
|
||||||
|
|
||||||
[[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
|
[[unroll]] for (uint iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
|
||||||
DEQUANT_FUNC
|
DEQUANT_FUNC
|
||||||
|
|
||||||
data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
|
data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
|
||||||
|
@ -481,6 +482,38 @@ void main() {
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
dequant_q4_0_body = """
|
||||||
|
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint i = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
// assume 32 threads
|
||||||
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
|
const uint il = tid/8;
|
||||||
|
const uint ir = tid%8;
|
||||||
|
const uint ib = 8*i + ir;
|
||||||
|
if (ib >= p.num_groups) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint b_idx = 256*i + 32*ir + 4*il;
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const float dm = -8.0f * d;
|
||||||
|
|
||||||
|
const uint q_idx = 4*il;
|
||||||
|
|
||||||
|
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||||
|
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm);
|
||||||
|
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
# K-quants
|
# K-quants
|
||||||
dequant_q2_K_body = """
|
dequant_q2_K_body = """
|
||||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
@ -488,29 +521,21 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||||
if (i >= p.M * p.K / QUANT_K) {
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int tid = int(gl_LocalInvocationID.x);
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
const int ip = tid / 32;
|
const uint ip = tid / 32;
|
||||||
const int il = tid - 32 * ip;
|
const uint il = tid - 32 * ip;
|
||||||
const int is = 8 * ip + il / 16;
|
const uint is = 8 * ip + il / 16;
|
||||||
|
|
||||||
const int y_idx = i * QUANT_K + 128 * ip + il;
|
const uint y_idx = i * QUANT_K + 128 * ip + il;
|
||||||
|
|
||||||
const int ql_idx = 32 * ip + il;
|
const uint ql_idx = 32 * ip + il;
|
||||||
const uint8_t qs = data_a[i].qs[32 * ip + il];
|
const uint8_t qs = data_a[i].qs[32 * ip + il];
|
||||||
|
|
||||||
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
||||||
|
@ -528,31 +553,23 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
|
||||||
if (i >= p.M * p.K / QUANT_K) {
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int r = int(gl_LocalInvocationID.x) / 4;
|
const uint r = gl_LocalInvocationID.x / 4;
|
||||||
const int tid = r / 2;
|
const uint tid = r / 2;
|
||||||
const int is0 = r % 2;
|
const uint is0 = r % 2;
|
||||||
const int l0 = 16 * is0 + 4 * (int(gl_LocalInvocationID.x) % 4);
|
const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
|
||||||
const int n = tid / 4;
|
const uint n = tid / 4;
|
||||||
const int j = tid - 4*n;
|
const uint j = tid - 4*n;
|
||||||
|
|
||||||
const uint8_t m = uint8_t(1 << (4*n + j));
|
const uint8_t m = uint8_t(1 << (4*n + j));
|
||||||
const int is = 8*n + 2*j + is0;
|
const uint is = 8*n + 2*j + is0;
|
||||||
const int shift = 2*j;
|
const uint shift = 2*j;
|
||||||
|
|
||||||
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
|
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
|
||||||
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
|
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
|
||||||
|
@ -561,10 +578,10 @@ void main() {
|
||||||
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
|
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
|
||||||
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
|
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
|
||||||
|
|
||||||
const int y_idx = i * QUANT_K + 128 * n + 32 * j;
|
const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
|
||||||
const int qs_idx = 32*n;
|
const uint qs_idx = 32*n;
|
||||||
|
|
||||||
for (int l = l0; l < l0 + 4; ++l) {
|
for (uint l = l0; l < l0 + 4; ++l) {
|
||||||
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
|
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -576,32 +593,24 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||||
if (i >= p.M * p.K / QUANT_K) {
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int tid = int(gl_LocalInvocationID.x);
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
const int il = tid / 8;
|
const uint il = tid / 8;
|
||||||
const int ir = tid % 8;
|
const uint ir = tid % 8;
|
||||||
const int is = 2 * il;
|
const uint is = 2 * il;
|
||||||
const int n = 4;
|
const uint n = 4;
|
||||||
|
|
||||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
||||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
||||||
|
|
||||||
const int y_idx = i * QUANT_K + 64 * il + n * ir;
|
const uint y_idx = i * QUANT_K + 64 * il + n * ir;
|
||||||
const int qs_idx = 32*il + n * ir;
|
const uint qs_idx = 32*il + n * ir;
|
||||||
|
|
||||||
uint8_t sc;
|
uint8_t sc;
|
||||||
uint8_t m;
|
uint8_t m;
|
||||||
|
@ -625,7 +634,7 @@ void main() {
|
||||||
const FLOAT_TYPE d2 = dall * sc;
|
const FLOAT_TYPE d2 = dall * sc;
|
||||||
const FLOAT_TYPE m2 = dmin * m;
|
const FLOAT_TYPE m2 = dmin * m;
|
||||||
|
|
||||||
[[unroll]] for (int l = 0; l < n; ++l) {
|
[[unroll]] for (uint l = 0; l < n; ++l) {
|
||||||
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
|
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
|
||||||
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
|
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
|
||||||
}
|
}
|
||||||
|
@ -638,32 +647,24 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||||
if (i >= p.M * p.K / QUANT_K) {
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int tid = int(gl_LocalInvocationID.x);
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
const int il = tid / 16;
|
const uint il = tid / 16;
|
||||||
const int ir = tid % 16;
|
const uint ir = tid % 16;
|
||||||
const int is = 2 * il;
|
const uint is = 2 * il;
|
||||||
|
|
||||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
||||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
||||||
|
|
||||||
const int y_idx = i * QUANT_K + 64 * il + 2 * ir;
|
const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
|
||||||
const int qs_idx = 32*il + 2 * ir;
|
const uint qs_idx = 32*il + 2 * ir;
|
||||||
const int qh_idx = 2 * ir;
|
const uint qh_idx = 2 * ir;
|
||||||
|
|
||||||
uint8_t sc;
|
uint8_t sc;
|
||||||
uint8_t m;
|
uint8_t m;
|
||||||
|
@ -702,28 +703,20 @@ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
|
||||||
{
|
|
||||||
int M;
|
|
||||||
int K;
|
|
||||||
int stride_a;
|
|
||||||
int stride_b;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||||
const int i = int(gl_WorkGroupID.x * 256 + wgy);
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||||
if (i >= p.M * p.K / QUANT_K) {
|
if (i >= p.M * p.K / QUANT_K) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const int tid = int(gl_LocalInvocationID.x);
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
const int ip = tid / 32;
|
const uint ip = tid / 32;
|
||||||
const int il = tid - 32 * ip;
|
const uint il = tid - 32 * ip;
|
||||||
const int is = 8 * ip + il / 16;
|
const uint is = 8 * ip + il / 16;
|
||||||
|
|
||||||
const int y_idx = i * QUANT_K + 128 * ip + il;
|
const uint y_idx = i * QUANT_K + 128 * ip + il;
|
||||||
|
|
||||||
const int ql_idx = 64 * ip + il;
|
const uint ql_idx = 64 * ip + il;
|
||||||
const uint8_t qh = data_a[i].qh[32 * ip + il];
|
const uint8_t qh = data_a[i].qh[32 * ip + il];
|
||||||
|
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
|
||||||
|
@ -2208,7 +2201,7 @@ async def main():
|
||||||
if i == GGML_TYPE_F16:
|
if i == GGML_TYPE_F16:
|
||||||
stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
|
stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
|
||||||
elif i == GGML_TYPE_Q4_0:
|
elif i == GGML_TYPE_Q4_0:
|
||||||
stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body))
|
stream.extend((shader_q4_0_defines, dequant_q4_0_body))
|
||||||
elif i == GGML_TYPE_Q4_1:
|
elif i == GGML_TYPE_Q4_1:
|
||||||
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
|
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
|
||||||
elif i == GGML_TYPE_Q5_0:
|
elif i == GGML_TYPE_Q5_0:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue