Speed up q4_0 dequant code, enable mmq for q4_0

commit 6314096db9
parent 93cdea1d7b
Author: 0cc4m
Date:   2024-02-28 22:14:14 +01:00

3 changed files with 5147 additions and 6522 deletions
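
The change is twofold: the q4_0 matmul shader's A-matrix load is rewritten so that each invocation dequantizes a single quantized byte (two 4-bit values) instead of unrolling a whole 32-value block, and ggml_vk_mul_mat_q_f16 gains an mmq path that feeds raw q4_0 data to the fused matmul pipeline instead of dequantizing src0 to fp16 first. For reference, a minimal host-side sketch of the q4_0 dequantization the shader performs (block layout as in ggml; fp16_to_fp32 stands in for ggml_fp16_to_fp32; this sketch is not part of the commit):

    #include <cstdint>

    // A block_q4_0 packs 32 weights: one fp16 scale d and 16 bytes of 4-bit quants.
    struct block_q4_0 {
        uint16_t d;       // fp16 scale (ggml_fp16_t in ggml)
        uint8_t  qs[16];  // 32 packed 4-bit values
    };

    // Dequantize one block: value j comes from the low nibble of qs[j] for j < 16
    // and from the high nibble of qs[j - 16] otherwise, offset by -8 and scaled
    // by d. This mirrors the shader's (vec2(vui & 0xF, vui >> 4) - 8.0f) * d.
    void dequant_block_q4_0(const block_q4_0 & b, float out[32], float (*fp16_to_fp32)(uint16_t)) {
        const float d = fp16_to_fp32(b.d);
        for (int j = 0; j < 16; ++j) {
            const uint8_t vui = b.qs[j];
            out[j]      = (float)((int)(vui & 0xF) - 8) * d;  // low nibble
            out[j + 16] = (float)((int)(vui >>  4) - 8) * d;  // high nibble
        }
    }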

File diff suppressed because it is too large

@@ -957,7 +957,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
 
-    std::initializer_list<uint32_t> warptile_mmq_regular = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };
@@ -989,7 +989,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
 
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
     } else {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
@@ -1012,7 +1012,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
 
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
     }
 
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -1461,6 +1461,20 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
     return ctx->device->pipeline_dequant[type];
 }
 
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_mat(ggml_backend_vk_context * ctx, ggml_type type) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_get_dequantize_mul_mat_mat()" << std::endl;
+#endif
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            return nullptr;
+    }
+
+    return ctx->device->pipeline_dequant_mul_mat_mat[type];
+}
+
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
@@ -2428,10 +2442,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
 
-    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
+    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+    vk_pipeline dmmm = ggml_vk_get_dequantize_mul_mat_mat(ctx, src0->type);
+
+    const bool qx_needs_dequant = (dmmm == nullptr && src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || x_non_contig;
+    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
@@ -2445,12 +2461,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
+    vk_pipeline pipeline = dmmm != nullptr ? dmmm : ggml_vk_guess_matmul_pipeline(ctx, true, !y_f32_kernel, ne01, ne11, aligned);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t x_sz = dmmm != nullptr ? ggml_nbytes(src0) : sizeof(ggml_fp16_t) * x_ne;
+    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_buffer d_D = extra->buffer_gpu.lock();
@@ -2481,7 +2497,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     } else {
         d_X = d_Qx;
         x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz); // NOLINT
+        GGML_ASSERT(qx_sz == x_sz);
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
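
Note: on the mmq path x_sz becomes ggml_nbytes(src0), i.e. the raw quantized bytes, and d_X aliases d_Qx, so the fp16 staging copy of src0 disappears entirely; the restored GGML_ASSERT(qx_sz == x_sz) checks exactly that. A worked size comparison for the 128x512 q4_0 test shape used later in this commit (that block_q4_0 occupies 18 bytes per 32 weights is a fact from ggml, not from this diff):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t m = 128, k = 512;  // src0 is m rows by k columns
        const uint64_t x_ne = m * k;
        // qx_sz follows the diff: ggml_type_size(Q4_0) * x_ne / ggml_blck_size(Q4_0)
        const uint64_t type_size = 18;    // block_q4_0: 2-byte fp16 d + 16 qs bytes
        const uint64_t blck_size = 32;    // weights per block
        const uint64_t qx_sz   = type_size * x_ne / blck_size;  // raw quantized bytes
        const uint64_t fp16_sz = 2 * x_ne;                       // old fp16 staging copy
        printf("qx_sz = %llu bytes, fp16 staging = %llu bytes (%.2fx larger)\n",
               (unsigned long long)qx_sz, (unsigned long long)fp16_sz,
               (double)fp16_sz / (double)qx_sz);
        return 0;
    }
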
@@ -3691,9 +3707,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     }
     for (size_t i = 0; i < y_ne; i++) {
         if (std::is_same<float, Y_TYPE>()) {
-            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            y[i] = (i % k == i / k) ? 1.0f : 0.0f;
         } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
         } else {
             GGML_ASSERT(false);
         }
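
Note: the tests swap the random y inputs for an identity matrix (the (i % k == i / k) expression; the tests run with n == k, so it marks the diagonal). With B = I the product reproduces the dequantized A exactly, so the extra error dumps added below point straight at wrongly dequantized elements. A small check of that identity trick (shapes simplified; not the ggml test itself):

    #include <cassert>
    #include <cstddef>

    int main() {
        const size_t m = 3, k = 4, n = 4;  // n == k, as in the 512x512 tests
        float a[m * k], y[k * n], d[m * n];
        for (size_t i = 0; i < m * k; ++i) a[i] = 0.25f * (float)i;  // arbitrary A
        // Identity matrix built exactly like the diff does it:
        for (size_t i = 0; i < k * n; ++i) y[i] = (i % k == i / k) ? 1.0f : 0.0f;
        // Plain row-major matmul: d = a * y
        for (size_t r = 0; r < m; ++r) {
            for (size_t c = 0; c < n; ++c) {
                float acc = 0.0f;
                for (size_t t = 0; t < k; ++t) acc += a[r * k + t] * y[t * n + c];
                d[r * n + c] = acc;
            }
        }
        // With y == I, the product equals a, element by element.
        for (size_t r = 0; r < m; ++r)
            for (size_t c = 0; c < n; ++c) assert(d[r * n + c] == a[r * k + c]);
        return 0;
    }
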
@@ -3791,6 +3809,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
         std::cerr << "Actual result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
@@ -4191,7 +4211,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     ggml_vk_quantize_data(x, qx, x_ne, quant);
 
     for (size_t i = 0; i < y_ne; i++) {
-        y[i] = rand() / (float)RAND_MAX;
+        // y[i] = rand() / (float)RAND_MAX;
+        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
@@ -4279,6 +4300,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
         std::cerr << "Actual result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 25, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 35, first_err_b);
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
@@ -4472,7 +4499,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     // ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
     // ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
 
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, GGML_TYPE_Q4_0);
     ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, GGML_TYPE_Q4_0);
 
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
 
     std::cerr << std::endl;

@@ -309,20 +309,22 @@ void main() {
 mulmat_load_scalar = """
 #if LOAD_VEC_A == 8
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx][0].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx][0].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx][0].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx][0].w);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 4] = FLOAT_TYPE(data_a[idx][1].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 5] = FLOAT_TYPE(data_a[idx][1].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 6] = FLOAT_TYPE(data_a[idx][1].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 7] = FLOAT_TYPE(data_a[idx][1].w);
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+        buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+        buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+        buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+        buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+        buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+        buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+        buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
 #elif LOAD_VEC_A == 4
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx].w);
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+        buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+        buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+        buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
 #else
         if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
             buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
@@ -334,34 +336,39 @@ mulmat_load_scalar = """
 mulmat_load_q4_0 = """
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 
-        [[unroll]] for (uint iqs = 0; iqs < 16; iqs++) {
-            const float d = float(data_a[idx].d);
-            const uint vui = uint(data_a[idx].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-            buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 0 ] = FLOAT_TYPE(v.x);
-            buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 16] = FLOAT_TYPE(v.y);
-        }"""
+        const uint ib = idx / 16;
+        const uint iqs = idx & 0xF;
+
+        const float d = float(data_a[ib].d);
+        const uint vui = uint(data_a[ib].qs[iqs]);
+        const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+        buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+        buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
 
 mulmat_body2 = """
     }
 
     [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
 #if LOAD_VEC_B == 8
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx][0].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx][0].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx][0].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx][0].w);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 4] = FLOAT_TYPE(data_b[idx][1].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 5] = FLOAT_TYPE(data_b[idx][1].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 6] = FLOAT_TYPE(data_b[idx][1].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 7] = FLOAT_TYPE(data_b[idx][1].w);
+        const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
+        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
+        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
+        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
+        buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
+        buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
+        buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
+        buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
 #elif LOAD_VEC_B == 4
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx].w);
+        const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
+        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
+        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
+        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
 #else
         if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
             buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
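
Note on the rewritten mulmat_load_q4_0 above: the old version had one invocation loop over all 16 bytes of a block; the new version, together with LOAD_VEC_A dropping from 32 to 2 in the generator call below, assigns each invocation a single quantized byte, computed as block ib = idx / 16 and byte iqs = idx & 0xF, whose two nibbles land 16 shared-memory columns apart. A host-side sketch of that mapping (plain C++, not shader code; columns counted relative to the row start, ignoring the pos_a offset):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // LOAD_VEC_A == 2: idx counts bytes of q4_0 data, 16 per block.
        const uint32_t samples[] = {0, 1, 15, 16, 17};
        for (uint32_t idx : samples) {
            const uint32_t ib  = idx / 16;   // which block_q4_0 in the row
            const uint32_t iqs = idx & 0xF;  // which qs[] byte inside the block
            // low nibble -> column ib*32 + iqs, high nibble -> 16 columns later
            printf("idx=%2u -> block %u byte %2u -> cols %2u and %2u\n",
                   idx, ib, iqs, ib * 32 + iqs, ib * 32 + iqs + 16);
        }
        return 0;
    }
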
@@ -2295,7 +2302,7 @@ async def main():
         stream.clear()
         stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 32, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
 
         # Shaders where precision is needed, so no fp16 version