Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0

This commit is contained in:
0cc4m 2024-02-10 19:54:10 +01:00
parent 5169f928c7
commit 2a0cf851d4
3 changed files with 3011 additions and 4789 deletions

File diff suppressed because it is too large Load diff

View file

@ -992,13 +992,12 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
// dequant shaders // dequant shaders
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 5 * sizeof(uint32_t), { 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16", dequant_f16_len, dequant_f16_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
@ -3436,7 +3435,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
} }
} }
ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it); ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
if (split_k > 1) { if (split_k > 1) {
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@ -3482,7 +3481,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
for (size_t i = 0; i < num_it; i++) { for (size_t i = 0; i < num_it; i++) {
ggml_vk_ctx_begin(ctx, subctx); ggml_vk_ctx_begin(ctx, subctx);
ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n); ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
ggml_vk_ctx_end(subctx); ggml_vk_ctx_end(subctx);
} }
@ -3600,7 +3599,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_vk_destroy_buffer(d_Y); ggml_vk_destroy_buffer(d_Y);
ggml_vk_destroy_buffer(d_D); ggml_vk_destroy_buffer(d_D);
ggml_pipeline_cleanup(*p); ggml_pipeline_cleanup(p);
ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
free(x); free(x);
@ -3836,9 +3835,12 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
std::vector<int64_t> hist_cur(1 << 4, 0); std::vector<int64_t> hist_cur(1 << 4, 0);
vk_pipeline& p = ctx->device->pipeline_dequant[quant]; vk_pipeline p = ctx->device->pipeline_dequant[quant];
switch(quant) { switch(quant) {
case GGML_TYPE_F32:
memcpy(qx, x, sizeof(float) * ne);
break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data()); ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
break; break;
@ -3894,12 +3896,34 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0; double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16); ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
int first_err = -1;
double avg_err = 0.0; double avg_err = 0.0;
for (size_t i = 0; i < ne; i++) { for (size_t i = 0; i < ne; i++) {
avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i])); double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
avg_err += error;
if (first_err < 0 && error > 0.05) {
first_err = i;
}
} }
std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl; avg_err /= ne;
std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
if (avg_err > 0.1) {
std::cerr << "first_error = " << first_err << std::endl;
std::cerr << "Actual result: " << std::endl << std::endl;
for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
}
std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
std::cerr << x[i] << ", ";
}
std::cerr << std::endl;
}
ggml_vk_destroy_buffer(x_buf); ggml_vk_destroy_buffer(x_buf);
ggml_vk_destroy_buffer(qx_buf); ggml_vk_destroy_buffer(qx_buf);
@ -4062,6 +4086,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
ggml_vk_test_transfer(ctx, 8192 * 1000, false); ggml_vk_test_transfer(ctx, 8192 * 1000, false);
ggml_vk_test_transfer(ctx, 8192 * 1000, true); ggml_vk_test_transfer(ctx, 8192 * 1000, true);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0); ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1); ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0); ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);

View file

@ -71,7 +71,7 @@ struct block_q5_1
{ {
float16_t d; float16_t d;
float16_t m; float16_t m;
uint qh; uint16_t qh[2];
uint8_t qs[16]; uint8_t qs[16];
}; };
@ -187,7 +187,8 @@ v = (v - 16.0f) * d;
shader_q5_1_dequant_func = """ shader_q5_1_dequant_func = """
#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
const float m = float(data_a[ib].m); \ const float m = float(data_a[ib].m); \
const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \ const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | uint(data_a[ib].qh[0]); \
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
const uint vui = uint(data_a[ib].qs[iqs]); \ const uint vui = uint(data_a[ib].qs[iqs]); \
vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = v*d + m; v = v*d + m;
@ -449,35 +450,21 @@ layout (push_constant) uniform parameter
} p; } p;
""" """
dequant_body = """ dequant_f32_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() { void main() {
const uint i = gl_GlobalInvocationID.x; const uint i = gl_GlobalInvocationID.x * 16;
// Transposed if (i >= p.nel) {
const uint row = i % (p.K / QUANT_K);
const uint col = i / (p.K / QUANT_K);
if (row * QUANT_K >= p.K || col >= p.M) {
return; return;
} }
const uint stride_a = p.stride_a / QUANT_K; [[unroll]] for (uint l = 0; l < 16; l++) {
data_b[i + l] = D_TYPE(data_a[i + l]);
const uint ib = col * stride_a + row;
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
const uint step = QUANT_R == 1 ? 2 : 1;
[[unroll]] for (uint iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
DEQUANT_FUNC
data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y);
} }
} }
""" """
@ -513,6 +500,134 @@ void main() {
} }
""" """
dequant_q4_1_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
}
}
"""
dequant_q5_0_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
}
}
"""
dequant_q5_1_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint qh = uint(data_a[ib].qh[1]) << 16 | uint(data_a[ib].qh[0]);
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
}
}
"""
dequant_q8_0_body = """
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 16*il;
const float d = float(data_a[ib].d);
const uint q_idx = 16*il;
[[unroll]] for (uint l = 0; l < 16; l += 2) {
data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
}
}
"""
# K-quants # K-quants
dequant_q2_K_body = """ dequant_q2_K_body = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
@ -1380,34 +1495,6 @@ void main() {
} }
""" """
# F16 to F32
f32_to_f16_src = """#version 450
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float16_t data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
const int row = int(gl_GlobalInvocationID.x % p.K);
const int col = int(gl_GlobalInvocationID.x / p.K);
if (row < p.K && col < p.M) {
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
}
}
"""
generic_head = """ generic_head = """
#version 450 #version 450
@ -2154,18 +2241,18 @@ async def main():
stream.extend((dequant_head, shader_int8_ext, shader_f32)) stream.extend((dequant_head, shader_int8_ext, shader_f32))
if i == GGML_TYPE_F16: if i == GGML_TYPE_F32:
stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body)) stream.append(dequant_f32_body)
elif i == GGML_TYPE_Q4_0: elif i == GGML_TYPE_Q4_0:
stream.extend((shader_q4_0_defines, dequant_q4_0_body)) stream.extend((shader_q4_0_defines, dequant_q4_0_body))
elif i == GGML_TYPE_Q4_1: elif i == GGML_TYPE_Q4_1:
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body)) stream.extend((shader_q4_1_defines, dequant_q4_1_body))
elif i == GGML_TYPE_Q5_0: elif i == GGML_TYPE_Q5_0:
stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body)) stream.extend((shader_q5_0_defines, dequant_q5_0_body))
elif i == GGML_TYPE_Q5_1: elif i == GGML_TYPE_Q5_1:
stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body)) stream.extend((shader_q5_1_defines, dequant_q5_1_body))
elif i == GGML_TYPE_Q8_0: elif i == GGML_TYPE_Q8_0:
stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body)) stream.extend((shader_q8_0_defines, dequant_q8_0_body))
elif i == GGML_TYPE_Q2_K: elif i == GGML_TYPE_Q2_K:
stream.extend((shader_q2_K_defines, dequant_q2_K_body)) stream.extend((shader_q2_K_defines, dequant_q2_K_body))
elif i == GGML_TYPE_Q3_K: elif i == GGML_TYPE_Q3_K:
@ -2181,8 +2268,6 @@ async def main():
tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"})) tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
# get_rows # get_rows
for i in range(0, VK_NUM_TYPES): for i in range(0, VK_NUM_TYPES):
stream.clear() stream.clear()