Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0
This commit is contained in:
parent
5169f928c7
commit
2a0cf851d4
3 changed files with 3011 additions and 4789 deletions
File diff suppressed because it is too large
Load diff
|
@ -992,13 +992,12 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
// dequant shaders
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 5 * sizeof(uint32_t), { 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
|
@ -3436,7 +3435,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|||
}
|
||||
}
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it);
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
|
@ -3482,7 +3481,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|||
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_ctx_begin(ctx, subctx);
|
||||
ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
|
||||
ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
|
||||
ggml_vk_ctx_end(subctx);
|
||||
}
|
||||
|
||||
|
@ -3600,7 +3599,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|||
ggml_vk_destroy_buffer(d_Y);
|
||||
ggml_vk_destroy_buffer(d_D);
|
||||
|
||||
ggml_pipeline_cleanup(*p);
|
||||
ggml_pipeline_cleanup(p);
|
||||
ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
|
||||
|
||||
free(x);
|
||||
|
@ -3836,9 +3835,12 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|||
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
vk_pipeline& p = ctx->device->pipeline_dequant[quant];
|
||||
vk_pipeline p = ctx->device->pipeline_dequant[quant];
|
||||
|
||||
switch(quant) {
|
||||
case GGML_TYPE_F32:
|
||||
memcpy(qx, x, sizeof(float) * ne);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
|
@ -3894,12 +3896,34 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|||
double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
|
||||
ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
|
||||
|
||||
int first_err = -1;
|
||||
|
||||
double avg_err = 0.0;
|
||||
for (size_t i = 0; i < ne; i++) {
|
||||
avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
|
||||
double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
|
||||
avg_err += error;
|
||||
|
||||
if (first_err < 0 && error > 0.05) {
|
||||
first_err = i;
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl;
|
||||
avg_err /= ne;
|
||||
|
||||
std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
|
||||
|
||||
if (avg_err > 0.1) {
|
||||
std::cerr << "first_error = " << first_err << std::endl;
|
||||
std::cerr << "Actual result: " << std::endl << std::endl;
|
||||
for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
|
||||
std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
|
||||
}
|
||||
std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
|
||||
for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
|
||||
std::cerr << x[i] << ", ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
|
||||
ggml_vk_destroy_buffer(x_buf);
|
||||
ggml_vk_destroy_buffer(qx_buf);
|
||||
|
@ -4062,6 +4086,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_test_transfer(ctx, 8192 * 1000, false);
|
||||
ggml_vk_test_transfer(ctx, 8192 * 1000, true);
|
||||
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
|
||||
|
|
|
@ -71,7 +71,7 @@ struct block_q5_1
|
|||
{
|
||||
float16_t d;
|
||||
float16_t m;
|
||||
uint qh;
|
||||
uint16_t qh[2];
|
||||
uint8_t qs[16];
|
||||
};
|
||||
|
||||
|
@ -187,7 +187,8 @@ v = (v - 16.0f) * d;
|
|||
shader_q5_1_dequant_func = """
|
||||
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
|
||||
const float m = float(data_a[ib].m); \
|
||||
const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
|
||||
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | uint(data_a[ib].qh[0]); \
|
||||
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
|
||||
const uint vui = uint(data_a[ib].qs[iqs]); \
|
||||
vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
|
||||
v = v*d + m;
|
||||
|
@ -449,35 +450,21 @@ layout (push_constant) uniform parameter
|
|||
} p;
|
||||
"""
|
||||
|
||||
dequant_body = """
|
||||
dequant_f32_body = """
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.x;
|
||||
const uint i = gl_GlobalInvocationID.x * 16;
|
||||
|
||||
// Transposed
|
||||
const uint row = i % (p.K / QUANT_K);
|
||||
const uint col = i / (p.K / QUANT_K);
|
||||
|
||||
if (row * QUANT_K >= p.K || col >= p.M) {
|
||||
if (i >= p.nel) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint stride_a = p.stride_a / QUANT_K;
|
||||
|
||||
const uint ib = col * stride_a + row;
|
||||
|
||||
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
const uint step = QUANT_R == 1 ? 2 : 1;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
|
||||
DEQUANT_FUNC
|
||||
|
||||
data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
|
||||
data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y);
|
||||
[[unroll]] for (uint l = 0; l < 16; l++) {
|
||||
data_b[i + l] = D_TYPE(data_a[i + l]);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
@ -513,6 +500,134 @@ void main() {
|
|||
}
|
||||
"""
|
||||
|
||||
dequant_q4_1_body = """
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint il = tid/32;
|
||||
const uint ir = tid%32;
|
||||
const uint ib = 32*i + ir;
|
||||
if (ib >= p.nel / 32) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint b_idx = 1024*i + 32*ir + 8*il;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float m = float(data_a[ib].m);
|
||||
|
||||
const uint q_idx = 8*il;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
dequant_q5_0_body = """
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint il = tid/32;
|
||||
const uint ir = tid%32;
|
||||
const uint ib = 32*i + ir;
|
||||
if (ib >= p.nel / 32) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint b_idx = 1024*i + 32*ir + 8*il;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
|
||||
|
||||
const uint q_idx = 8*il;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
const uint iqs = q_idx + l;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
dequant_q5_1_body = """
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint il = tid/32;
|
||||
const uint ir = tid%32;
|
||||
const uint ib = 32*i + ir;
|
||||
if (ib >= p.nel / 32) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint b_idx = 1024*i + 32*ir + 8*il;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float m = float(data_a[ib].m);
|
||||
const uint qh = uint(data_a[ib].qh[1]) << 16 | uint(data_a[ib].qh[0]);
|
||||
|
||||
const uint q_idx = 8*il;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
const uint iqs = q_idx + l;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
dequant_q8_0_body = """
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint il = tid/32;
|
||||
const uint ir = tid%32;
|
||||
const uint ib = 32*i + ir;
|
||||
if (ib >= p.nel / 32) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint b_idx = 1024*i + 32*ir + 16*il;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
|
||||
const uint q_idx = 16*il;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 16; l += 2) {
|
||||
data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
|
||||
data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# K-quants
|
||||
dequant_q2_K_body = """
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
|
@ -1380,34 +1495,6 @@ void main() {
|
|||
}
|
||||
"""
|
||||
|
||||
# F16 to F32
|
||||
f32_to_f16_src = """#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float16_t data_b[];};
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int row = int(gl_GlobalInvocationID.x % p.K);
|
||||
const int col = int(gl_GlobalInvocationID.x / p.K);
|
||||
|
||||
if (row < p.K && col < p.M) {
|
||||
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
generic_head = """
|
||||
#version 450
|
||||
|
||||
|
@ -2154,18 +2241,18 @@ async def main():
|
|||
|
||||
stream.extend((dequant_head, shader_int8_ext, shader_f32))
|
||||
|
||||
if i == GGML_TYPE_F16:
|
||||
stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
|
||||
if i == GGML_TYPE_F32:
|
||||
stream.append(dequant_f32_body)
|
||||
elif i == GGML_TYPE_Q4_0:
|
||||
stream.extend((shader_q4_0_defines, dequant_q4_0_body))
|
||||
elif i == GGML_TYPE_Q4_1:
|
||||
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
|
||||
stream.extend((shader_q4_1_defines, dequant_q4_1_body))
|
||||
elif i == GGML_TYPE_Q5_0:
|
||||
stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body))
|
||||
stream.extend((shader_q5_0_defines, dequant_q5_0_body))
|
||||
elif i == GGML_TYPE_Q5_1:
|
||||
stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body))
|
||||
stream.extend((shader_q5_1_defines, dequant_q5_1_body))
|
||||
elif i == GGML_TYPE_Q8_0:
|
||||
stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body))
|
||||
stream.extend((shader_q8_0_defines, dequant_q8_0_body))
|
||||
elif i == GGML_TYPE_Q2_K:
|
||||
stream.extend((shader_q2_K_defines, dequant_q2_K_body))
|
||||
elif i == GGML_TYPE_Q3_K:
|
||||
|
@ -2181,8 +2268,6 @@ async def main():
|
|||
|
||||
tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
|
||||
|
||||
tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
|
||||
|
||||
# get_rows
|
||||
for i in range(0, VK_NUM_TYPES):
|
||||
stream.clear()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue