Speed up q4_0 dequant code, enable mmq for q4_0
parent 93cdea1d7b
commit 6314096db9

3 changed files with 5147 additions and 6522 deletions.

ggml-vulkan-shaders.hpp (11538 changed lines): diff suppressed, the generated shader header is too large to display.
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -957,7 +957,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
 
-    std::initializer_list<uint32_t> warptile_mmq_regular = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
@@ -989,7 +989,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
 
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
     } else {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
@@ -1012,7 +1012,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
         ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
 
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
     }
 
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -1461,6 +1461,20 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
     return ctx->device->pipeline_dequant[type];
 }
 
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_mat(ggml_backend_vk_context * ctx, ggml_type type) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_get_dequantize_mul_mat_mat()" << std::endl;
+#endif
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            return nullptr;
+    }
+
+    return ctx->device->pipeline_dequant_mul_mat_mat[type];
+}
+
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
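A note on the new helper: the switch acts as a whitelist, so ggml_vk_get_dequantize_mul_mat_mat doubles as a capability probe. It returns the matmul-on-quantized-data (mmq) pipeline only for types that have a dedicated shader, currently just Q4_0, and nullptr for everything else; callers treat nullptr as "fall back to dequantize plus fp16 matmul". A hedged sketch of what widening the whitelist might look like (Q4_1 is a hypothetical example, not part of this commit):

    // Hypothetical extension, not in this commit: enabling mmq for another
    // quant type would need its case label here plus a matching
    // matmul_<type>_f32_aligned shader emitted by the generator script.
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1: // assumed future addition
            break;
        default:
            return nullptr;
    }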
@@ -2428,10 +2442,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
 
-    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
+    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+    vk_pipeline dmmm = ggml_vk_get_dequantize_mul_mat_mat(ctx, src0->type);
+
+    const bool qx_needs_dequant = (dmmm == nullptr && src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || x_non_contig;
+    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
@@ -2445,12 +2461,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
 
     const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
+    vk_pipeline pipeline = dmmm != nullptr ? dmmm : ggml_vk_guess_matmul_pipeline(ctx, true, !y_f32_kernel, ne01, ne11, aligned);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t x_sz = dmmm != nullptr ? ggml_nbytes(src0) : sizeof(ggml_fp16_t) * x_ne;
+    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_buffer d_D = extra->buffer_gpu.lock();
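The x_sz change above is where the bandwidth saving shows up: when an mmq pipeline is selected, the X buffer holds the raw quantized blocks (ggml_nbytes(src0)) instead of a dequantized fp16 copy, and the separate dequant dispatch is skipped. For Q4_0, a 32-weight block is one fp16 scale plus 16 packed bytes. A quick size check (block sizes follow ggml's block_q4_0 layout; the snippet is illustrative, not ggml API):

    #include <cstddef>

    // Q4_0: 32 weights per block = 2-byte fp16 scale + 16 nibble bytes.
    constexpr std::size_t q4_0_block_bytes = 2 + 16;  // 18 bytes per 32 weights
    constexpr std::size_t fp16_block_bytes = 32 * 2;  // 64 bytes per 32 weights
    // The mmq path reads ~3.5x less data and skips the dequant pass entirely:
    // a 4096 x 4096 weight matrix is 9 MiB as Q4_0 blocks vs 32 MiB as fp16.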
@@ -2481,7 +2497,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     } else {
         d_X = d_Qx;
         x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz); // NOLINT
+        GGML_ASSERT(qx_sz == x_sz);
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
@@ -3691,9 +3707,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     }
     for (size_t i = 0; i < y_ne; i++) {
         if (std::is_same<float, Y_TYPE>()) {
-            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            y[i] = (i % k == i / k) ? 1.0f : 0.0f;
         } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
         } else {
             GGML_ASSERT(false);
         }
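Replacing the random Y data with an identity matrix is a debugging aid: multiplying by the identity should reproduce the (dequantized) X operand exactly, so a wrong element pinpoints a bad index in the shader rather than a diffuse numeric error. The predicate (i % k == i / k) marks the diagonal of the matrix stored flat with k rows: i % k and i / k are the two coordinates of element i, and they coincide exactly on the diagonal. A small standalone check of that predicate (not part of the test harness):

    #include <cassert>
    #include <cstddef>

    int main() {
        const std::size_t k = 4;
        for (std::size_t i = 0; i < k * k; i++) {
            const bool diag = (i % k) == (i / k);  // row == column
            assert(diag == (i % (k + 1) == 0));    // hits i = 0, 5, 10, 15
        }
        return 0;
    }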
@@ -3791,6 +3809,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
         std::cerr << "Actual result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
@@ -4191,7 +4211,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     ggml_vk_quantize_data(x, qx, x_ne, quant);
 
     for (size_t i = 0; i < y_ne; i++) {
-        y[i] = rand() / (float)RAND_MAX;
+        // y[i] = rand() / (float)RAND_MAX;
+        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
|
||||||
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
||||||
std::cerr << "Actual result: " << std::endl << std::endl;
|
std::cerr << "Actual result: " << std::endl << std::endl;
|
||||||
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||||
|
std::cerr << std::endl;
|
||||||
|
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
|
||||||
|
std::cerr << std::endl;
|
||||||
|
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 25, first_err_b);
|
||||||
|
std::cerr << std::endl;
|
||||||
|
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 35, first_err_b);
|
||||||
std::cerr << "Expected result: " << std::endl << std::endl;
|
std::cerr << "Expected result: " << std::endl << std::endl;
|
||||||
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||||
|
|
||||||
|
@@ -4472,7 +4499,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     // ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
     // ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
 
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, GGML_TYPE_Q4_0);
     ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, GGML_TYPE_Q4_0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
 
     std::cerr << std::endl;
 
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -309,20 +309,22 @@ void main() {
 mulmat_load_scalar = """
 #if LOAD_VEC_A == 8
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx][0].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx][0].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx][0].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx][0].w);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 4] = FLOAT_TYPE(data_a[idx][1].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 5] = FLOAT_TYPE(data_a[idx][1].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 6] = FLOAT_TYPE(data_a[idx][1].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 7] = FLOAT_TYPE(data_a[idx][1].w);
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+        buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+        buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+        buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+        buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+        buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+        buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+        buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
 #elif LOAD_VEC_A == 4
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx].x);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx].y);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx].z);
-        buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx].w);
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+        buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+        buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+        buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
 #else
         if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
             buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
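This hoisting (here for buf_a, and below for buf_b) does not change behavior: the repeated base-offset expression is computed once into buf_idx per load iteration, a manual common-subexpression elimination that shortens the generated GLSL and makes it explicit that the eight (or four) shared-memory stores hit consecutive slots.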
@@ -334,34 +336,39 @@ mulmat_load_scalar = """
 
 mulmat_load_q4_0 = """
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        [[unroll]] for (uint iqs = 0; iqs < 16; iqs++) {
-            const float d = float(data_a[idx].d);
-            const uint vui = uint(data_a[idx].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
 
-            buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 0 ] = FLOAT_TYPE(v.x);
-            buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 16] = FLOAT_TYPE(v.y);
-        }"""
+        const uint ib = idx / 16;
+        const uint iqs = idx & 0xF;
+
+        const float d = float(data_a[ib].d);
+        const uint vui = uint(data_a[ib].qs[iqs]);
+        const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+        const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+        buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
+        buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);"""
 
 mulmat_body2 = """
     }
     [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
 #if LOAD_VEC_B == 8
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx][0].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx][0].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx][0].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx][0].w);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 4] = FLOAT_TYPE(data_b[idx][1].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 5] = FLOAT_TYPE(data_b[idx][1].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 6] = FLOAT_TYPE(data_b[idx][1].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 7] = FLOAT_TYPE(data_b[idx][1].w);
+        const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
+        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
+        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
+        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
+        buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
+        buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
+        buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
+        buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
 #elif LOAD_VEC_B == 4
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx].x);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx].y);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx].z);
-        buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx].w);
+        const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
+        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
+        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
+        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
 #else
         if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
             buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
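The q4_0 loader rewrite changes the work partitioning, which is the headline speedup. Previously each thread looped over all 16 packed bytes of a block, dequantizing 32 values; now each thread handles exactly one byte: idx addresses packed bytes (LOAD_VEC_A drops to 2, see the generator change below), ib = idx / 16 selects the block since block_q4_0 carries 16 qs bytes, iqs = idx & 0xF selects the byte within it, and the two nibbles land 16 slots apart because Q4_0 stores low nibbles as block values 0-15 and high nibbles as values 16-31. For reference, a scalar C++ sketch of the same dequantization (the struct mirrors ggml's block_q4_0; names are illustrative):

    #include <cstdint>

    // Q4_0 block as in ggml: one fp16 scale plus 32 packed 4-bit quants.
    struct block_q4_0 {
        uint16_t d;       // fp16 scale bits (float conversion omitted here)
        uint8_t  qs[16];  // two 4-bit values per byte
    };

    // Expand one block into out[0..31], mirroring the shader's math:
    // value = (nibble - 8) * d; low nibble -> j, high nibble -> j + 16.
    void dequant_q4_0(const block_q4_0& b, float d, float out[32]) {
        for (int j = 0; j < 16; j++) {
            const uint8_t vui = b.qs[j];
            out[j]      = ((vui & 0xF) - 8.0f) * d;
            out[j + 16] = ((vui >> 4)  - 8.0f) * d;
        }
    }

One shader invocation now executes a single iteration of this loop, so sixteen threads fill a block cooperatively instead of one thread doing all of it.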
@@ -2295,7 +2302,7 @@ async def main():
 
         stream.clear()
         stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
-        tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 32, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
+        tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
 
         # Shaders where precision is needed, so no fp16 version
 
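The LOAD_VEC_A constant must agree with the loader string: it is the number of buf_a values one load iteration produces. The old block-per-thread loop produced 32, the new byte-per-thread loader produces 2 (one low plus one high nibble), and since idx is computed as pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a, the constant also rescales the row stride from weight elements to packed bytes. Generator flag and GLSL therefore have to change in lockstep, which is why both hunks appear in this commit.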