Add argsort

Basic q4_0 mmq shader and unit test
commit aa0f428e2a
parent 0caf8dc906
Author: 0cc4m
Date:   2024-02-25 17:06:10 +01:00

3 changed files with 12158 additions and 8497 deletions
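This commit wires GGML_OP_ARGSORT into the Vulkan backend (a per-row bitonic sort shader plus the pipeline and dispatch plumbing) and adds a first matrix-matrix multiplication shader that reads q4_0 data directly, together with a unit test for it. For orientation, the sketch below is not part of the commit: it shows how the argsort op is driven from the host through the ggml graph API, assumes the ggml host API of this period (ggml_argsort, GGML_SORT_ASC), and mirrors the graph-building calls the test harness in this diff already uses.

// Hedged sketch (not part of this commit): driving the new GGML_OP_ARGSORT
// path through the ordinary ggml graph API. ggml_argsort yields an I32
// tensor holding the sorted indices of each row of its input.
#include <stdlib.h>
#include "ggml.h"

static void argsort_example(void) {
    struct ggml_init_params iparams = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(iparams);

    // one row of 1024 floats; the Vulkan shader sorts each row with one workgroup
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    for (int i = 0; i < 1024; i++) {
        ((float *) a->data)[i] = rand() / (float) RAND_MAX;
    }

    struct ggml_tensor * idx = ggml_argsort(ctx, a, GGML_SORT_ASC);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, idx);
    ggml_graph_compute_with_ctx(ctx, gf, 1); // idx->data now holds the index permutation

    ggml_free(ctx);
}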

File diff suppressed because it is too large.

ggml-vulkan.cpp

@@ -114,6 +114,8 @@ struct vk_device {
vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
vk_pipeline pipeline_dequant[VK_NUM_TYPES];
vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
@@ -136,6 +138,7 @@ struct vk_device {
vk_pipeline pipeline_soft_max_f32;
vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_argsort_f32;
std::vector<vk_pipeline_ref> pipelines;
@@ -261,6 +264,11 @@ struct vk_op_rope_neox_push_constants {
float inv_ndims;
};
struct vk_op_argsort_push_constants {
uint32_t ncols;
bool ascending;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -937,6 +945,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
std::initializer_list<uint32_t> warptile_mmq_regular = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
std::array<uint32_t, 3> s_wg_denoms = { 32, 32, 1 };
@@ -966,6 +976,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
} else {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
@@ -987,6 +999,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
}
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -1064,6 +1078,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
}
static void ggml_vk_print_gpu_info(size_t idx) {
@@ -1851,6 +1867,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
ggml_vk_submit(subctx, ctx->fence);
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
ctx->device->device.resetFences({ ctx->fence });
ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
}
}
@@ -1947,6 +1964,7 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
for (auto& cpy : subctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
}
}
@@ -2802,6 +2820,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
}
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
}
static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// guaranteed to be an integer due to the check in ggml_can_repeat
const uint64_t ne0 = dst->ne[0];
@@ -2981,6 +3003,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
}
case GGML_OP_ARGSORT:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argsort_f32;
}
return nullptr;
default:
return nullptr;
}
@@ -3358,6 +3385,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
}
}
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
int32_t * op_params = (int32_t *)dst->op_params;
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ASC });
}
static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
// If backend is CPU, data from src0 has to be copied off the device
if (dst->backend == GGML_BACKEND_CPU) {
@@ -3874,6 +3906,48 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
}
}
static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
std::vector<int64_t> hist_cur(1 << 4, 0);
switch(quant) {
case GGML_TYPE_F32:
memcpy(to, from, sizeof(float) * ne);
break;
case GGML_TYPE_Q4_0:
ggml_quantize_q4_0(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q4_1:
ggml_quantize_q4_1(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_0:
ggml_quantize_q5_0(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_1:
ggml_quantize_q5_1(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q8_0:
ggml_quantize_q8_0(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q2_K:
ggml_quantize_q2_K(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q3_K:
ggml_quantize_q3_K(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q4_K:
ggml_quantize_q4_K(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_K:
ggml_quantize_q5_K(from, to, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q6_K:
ggml_quantize_q6_K(from, to, ne, ne, hist_cur.data());
break;
default:
GGML_ASSERT(false);
}
}
static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
@@ -3891,47 +3965,9 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
x[i] = rand() / (float)RAND_MAX;
}
std::vector<int64_t> hist_cur(1 << 4, 0);
vk_pipeline p = ctx->device->pipeline_dequant[quant];
switch(quant) {
case GGML_TYPE_F32:
memcpy(qx, x, sizeof(float) * ne);
break;
case GGML_TYPE_Q4_0:
ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q4_1:
ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_0:
ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_1:
ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q8_0:
ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q2_K:
ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q3_K:
ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q4_K:
ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q5_K:
ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
break;
case GGML_TYPE_Q6_K:
ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
break;
default:
GGML_ASSERT(false);
}
ggml_vk_quantize_data(x, qx, ne, quant);
ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
@@ -3990,6 +4026,158 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
free(qx);
free(x_chk);
}
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, ggml_type quant) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
#endif
const size_t x_ne = m * k * batch;
const size_t y_ne = k * n * batch;
const size_t d_ne = m * n * batch;
const size_t x_sz = sizeof(float) * x_ne;
const size_t y_sz = sizeof(float) * y_ne;
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
const size_t d_sz = sizeof(float) * d_ne;
float * x = (float *) malloc(x_sz);
float * y = (float *) malloc(y_sz);
void * qx = malloc(qx_sz);
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
float * d = (float *) malloc(d_sz);
float * d_chk = (float *) malloc(d_sz);
for (size_t i = 0; i < x_ne; i++) {
x[i] = rand() / (float)RAND_MAX;
}
vk_pipeline p = ctx->device->pipeline_dequant_mul_mat_mat[quant];
ggml_vk_quantize_data(x, qx, x_ne, quant);
for (size_t i = 0; i < y_ne; i++) {
y[i] = rand() / (float)RAND_MAX;
}
ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
// Resize buffer
if (ctx->prealloc_split_k != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
}
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
}
}
ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_ctx_begin(ctx, subctx);
ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
ggml_vk_ctx_end(subctx);
}
auto begin = std::chrono::high_resolution_clock::now();
ggml_vk_submit(subctx, ctx->fence);
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
ctx->device->device.resetFences({ ctx->fence });
auto end = std::chrono::high_resolution_clock::now();
double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
ggml_init_params iparams = {
/*.mem_size =*/ 1024*1024*1024,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context * ggml_ctx = ggml_init(iparams);
ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
src0_ggml->data = qx;
src1_ggml->data = y;
tensor_ggml->data = d_chk;
ctx->disable = true;
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph, tensor_ggml);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
ctx->disable = false;
ggml_free(ggml_ctx);
double avg_err = 0.0;
int first_err_n = -1;
int first_err_m = -1;
int first_err_b = -1;
for (size_t i = 0; i < m*n*batch; i++) {
double err = std::fabs(d[i] - d_chk[i]);
avg_err += err;
if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
first_err_b = i / (m * n);
first_err_n = (i % (m * n)) / m;
first_err_m = (i % (m * n)) % m;
}
}
avg_err /= m * n * batch;
std::cerr << "TEST MMQ " << ggml_type_name(quant) << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
if (avg_err > 0.1 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
std::cerr << "Actual result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "Expected result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
if (split_k > 1) {
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
std::cerr << "d_buf0: " << std::endl << std::endl;
ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "d_buf1: " << std::endl << std::endl;
ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "d_buf2: " << std::endl << std::endl;
ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "d_buf3: " << std::endl << std::endl;
ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
free(split_k_buf);
}
}
ggml_vk_destroy_buffer(qx_buf);
ggml_vk_destroy_buffer(y_buf);
ggml_vk_destroy_buffer(d_buf);
free(x);
free(qx);
free(y);
free(d);
free(d_chk);
}
#endif
static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
@@ -4002,18 +4190,8 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
return extra;
}
static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
GGML_ASSERT(node != nullptr);
for (int i = graph->n_nodes - 1; i >= 0; i--) {
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (graph->nodes[i]->src[j] == node) {
return graph->nodes[i];
}
}
}
return nullptr;
static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
}
static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
@@ -4024,7 +4202,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU));
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
return;
}
@@ -4055,7 +4233,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
int split_k;
if (node->op == GGML_OP_MUL_MAT) {
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
} else {
split_k = 1;
@@ -4096,6 +4274,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ARGSORT:
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
@@ -4108,6 +4287,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
if (ctx->prealloc_size_qx < qx_sz) {
ctx->prealloc_size_qx = qx_sz;
}
@@ -4144,17 +4324,21 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
ggml_vk_test_transfer(ctx, 8192 * 1000, false);
ggml_vk_test_transfer(ctx, 8192 * 1000, true);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, GGML_TYPE_Q4_0);
std::cerr << std::endl;
const std::vector<size_t> vals {
8, 8, 8,
@@ -4242,7 +4426,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU);
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
return;
}
@@ -4288,7 +4472,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_NONE:
case GGML_OP_ARGSORT:
break;
default:
if (any_on_device) {
@@ -4376,10 +4562,17 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_ROPE:
ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
break;
case GGML_OP_ARGSORT:
ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
break;
case GGML_OP_MUL_MAT:
ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
break;
case GGML_OP_MUL_MAT_ID:
ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
break;
default:
return;
@@ -4406,7 +4599,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
return false;
}
@@ -4432,6 +4625,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_NONE:
case GGML_OP_ARGSORT:
extra = (ggml_tensor_extra_gpu *) tensor->extra;
break;
@@ -4447,6 +4641,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
return false;
}
@@ -5164,6 +5359,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a;
struct ggml_tensor * b;
@@ -5237,6 +5433,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ARGSORT:
return true;
default:
return false;

ggml_vk_generate_shaders.py

@@ -207,12 +207,15 @@ mulmat_head = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#ifndef LOAD_VEC
#define LOAD_VEC 1
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
#define LOAD_VEC_B 1
#endif
"""
mulmat_body = """
mulmat_body1 = """
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -241,7 +244,7 @@ layout (push_constant) uniform parameter
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
@@ -278,16 +281,19 @@ void main() {
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
@@ -298,61 +304,77 @@ void main() {
}
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride) {
#if LOAD_VEC == 8
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC == 4
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {"""
mulmat_load_scalar = """
#if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx][0].z);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx][0].w);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 4] = FLOAT_TYPE(data_a[idx][1].x);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 5] = FLOAT_TYPE(data_a[idx][1].y);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 6] = FLOAT_TYPE(data_a[idx][1].z);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx].x);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
}
#endif
"""
mulmat_load_q4_0 = """
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
[[unroll]] for (uint iqs = 0; iqs < 16; iqs++) {
const float d = float(data_a[idx].d);
const uint vui = uint(data_a[idx].qs[iqs]);
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 0 ] = FLOAT_TYPE(v.x);
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 16] = FLOAT_TYPE(v.y);
}"""
mulmat_body2 = """
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride) {
#if LOAD_VEC == 8
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC == 4
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx][0].z);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx][0].w);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 4] = FLOAT_TYPE(data_b[idx][1].x);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 5] = FLOAT_TYPE(data_b[idx][1].y);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 6] = FLOAT_TYPE(data_b[idx][1].z);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx].w);
#else
if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
}
#endif
}
barrier();
pos_a += BK / LOAD_VEC;
pos_b += BK / LOAD_VEC;
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (uint i = 0; i < BK; i++) {
// Load from shared into cache
@@ -2085,6 +2107,65 @@ void main() {
}
"""
argsort_src = """
#version 450
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
bool ascending;
} p;
void swap(uint idx0, uint idx1) {
int tmp = data_d[idx0];
data_d[idx0] = data_d[idx1];
data_d[idx1] = tmp;
}
void main() {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
if (col >= p.ncols) {
return;
}
const uint a_idx = row * p.ncols;
const uint d_idx = row * p.ncols;
// initialize indices
if (col < p.ncols) {
data_d[d_idx + col] = col;
}
barrier();
for (uint k = 2; k <= p.ncols; k *= 2) {
for (uint j = k / 2; j > 0; j /= 2) {
const uint ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
swap(d_idx + col, d_idx + ixj);
}
} else {
if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
swap(d_idx + col, d_idx + ixj);
}
}
}
barrier();
}
}
}
"""
GLSLC = "glslc"
VK_NUM_TYPES = 16
@@ -2202,15 +2283,19 @@ async def main():
vec_type = "vec4"
stream.clear()
stream.extend((mulmat_head, shader_float_type, mulmat_body))
stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
stream.clear()
stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 32, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
# Shaders where precision is needed, so no fp16 version
@@ -2338,6 +2423,8 @@ async def main():
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
# Helper to decorate tasks with semaphore acquisition.
async def withSemaphore(sem, task):
async with sem: