Add argsort
Basic q4_0 mmq shader and unit test

parent 0caf8dc906
commit aa0f428e2a

3 changed files with 12158 additions and 8497 deletions

ggml-vulkan-shaders.hpp: 20129 changed lines (file diff suppressed because it is too large)
ggml-vulkan.cpp: 329 changed lines
@@ -114,6 +114,8 @@ struct vk_device {
    vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
    vk_pipeline pipeline_matmul_split_k_reduce;

    vk_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];

    vk_pipeline pipeline_dequant[VK_NUM_TYPES];
    vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];

@@ -136,6 +138,7 @@ struct vk_device {
    vk_pipeline pipeline_soft_max_f32;
    vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
    vk_pipeline pipeline_argsort_f32;

    std::vector<vk_pipeline_ref> pipelines;
@@ -261,6 +264,11 @@ struct vk_op_rope_neox_push_constants {
    float inv_ndims;
};

struct vk_op_argsort_push_constants {
    uint32_t ncols;
    bool ascending;
};

// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -937,6 +945,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
|||
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
||||
std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
|
||||
|
||||
std::initializer_list<uint32_t> warptile_mmq_regular = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
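A hedged reading of warptile_mmq_regular, assuming its entries feed the specialization constants in declaration order: constant_ids 0 through 6 (workgroup size, BM, BN, BK, WM, WN, WMITER) are visible in the shader source later in this diff, while the last three entries are assumed to be the per-thread tile sizes and the warp size. This is an annotated sketch, not the upstream definition:

#include <cstdint>
#include <initializer_list>

constexpr uint32_t kSubgroupSize = 32; // placeholder for device->subgroup_size

const std::initializer_list<uint32_t> warptile_mmq_regular_annotated = {
    kSubgroupSize, // constant_id 0: local_size_x (workgroup size)
    32,            // constant_id 1: BM, rows of the A tile per workgroup
    32,            // constant_id 2: BN, cols of the B tile per workgroup
    32,            // constant_id 3: BK, depth per iteration (one q4_0 block)
    32,            // constant_id 4: WM
    32,            // constant_id 5: WN
    2,             // constant_id 6: WMITER
    2,             // assumed: TM (per-thread tile rows)
    2,             // assumed: TN (per-thread tile cols)
    kSubgroupSize, // assumed: warp/subgroup size
};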
|
||||
|
||||
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
|
||||
std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
|
||||
std::array<uint32_t, 3> s_wg_denoms = { 32, 32, 1 };
|
||||
|
@ -966,6 +976,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
|
||||
|
@ -987,6 +999,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_regular, s_align);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
@@ -1064,6 +1078,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {

    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);

    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
}

static void ggml_vk_print_gpu_info(size_t idx) {
@ -1851,6 +1867,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
|
|||
ggml_vk_submit(subctx, ctx->fence);
|
||||
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
|
||||
ctx->device->device.resetFences({ ctx->fence });
|
||||
ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1947,6 +1964,7 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
|
|||
for (auto& cpy : subctx->out_memcpys) {
|
||||
memcpy(cpy.dst, cpy.src, cpy.n);
|
||||
}
|
||||
ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2802,6 +2820,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||
|
||||
}
|
||||
|
||||
static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
// guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const uint64_t ne0 = dst->ne[0];
|
||||
|
@@ -2981,6 +3003,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            }
            return nullptr;
        }
    case GGML_OP_ARGSORT:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
            return ctx->device->pipeline_argsort_f32;
        }
        return nullptr;
    default:
        return nullptr;
    }
|
@@ -3358,6 +3385,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
    }
}

static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    int32_t * op_params = (int32_t *)dst->op_params;
    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ASC });
}
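For context, the GGML_OP_ARGSORT node handled by ggml_vk_argsort() is normally produced through the public ggml API. A minimal host-side sketch, assuming the ggml_argsort()/GGML_SORT_ASC API referenced above:

#include "ggml.h"

// Build an argsort node over the rows of an F32 tensor; dst->op_params[0]
// carries the sort order, which ggml_vk_argsort() reads back above.
static struct ggml_tensor * build_argsort(struct ggml_context * ctx, struct ggml_tensor * src) {
    // src: F32 input; the Vulkan pipeline writes I32 indices into dst.
    return ggml_argsort(ctx, src, GGML_SORT_ASC);
}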
|
||||
static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
// If backend is CPU, data from src0 has to be copied off the device
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
|
@ -3874,6 +3906,48 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
switch(quant) {
|
||||
case GGML_TYPE_F32:
|
||||
memcpy(to, from, sizeof(float) * ne);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
ggml_quantize_q4_0(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
ggml_quantize_q4_1(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
ggml_quantize_q5_0(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
ggml_quantize_q5_1(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
ggml_quantize_q8_0(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
ggml_quantize_q2_K(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
ggml_quantize_q3_K(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
ggml_quantize_q4_K(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
ggml_quantize_q5_K(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
ggml_quantize_q6_K(from, to, ne, ne, hist_cur.data());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
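For reference, a minimal sketch of the q4_0 block layout this helper emits for GGML_TYPE_Q4_0, and of the scalar dequantization the new mmq shader mirrors. The struct name and the fp16-to-float handling here are illustrative stand-ins, not the ggml definitions:

#include <cstdint>

struct block_q4_0_sketch {       // illustrative stand-in for ggml's block_q4_0
    uint16_t d;                  // fp16 scale for the block
    uint8_t  qs[16];             // 32 x 4-bit quants, two per byte
};

// Dequantize one 32-element block the same way the Vulkan loader does:
// low nibble -> element i, high nibble -> element i + 16, value = (q - 8) * d.
static void dequant_block_q4_0(const block_q4_0_sketch & b, float d_scale, float out[32]) {
    for (int i = 0; i < 16; ++i) {
        const uint8_t v = b.qs[i];
        out[i +  0] = (float)((v & 0x0F) - 8) * d_scale;
        out[i + 16] = (float)((v >>   4) - 8) * d_scale;
    }
}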
|
||||
|
||||
static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
|
||||
#ifdef GGML_VULKAN_DEBUG
|
||||
std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
|
||||
|
@ -3891,47 +3965,9 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|||
x[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
vk_pipeline p = ctx->device->pipeline_dequant[quant];
|
||||
|
||||
switch(quant) {
|
||||
case GGML_TYPE_F32:
|
||||
memcpy(qx, x, sizeof(float) * ne);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
ggml_vk_quantize_data(x, qx, ne, quant);
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
|
||||
|
||||
|
@ -3990,6 +4026,158 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|||
free(qx);
|
||||
free(x_chk);
|
||||
}
|
||||
|
||||
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, ggml_type quant) {
|
||||
#ifdef GGML_VULKAN_DEBUG
|
||||
std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
|
||||
#endif
|
||||
const size_t x_ne = m * k * batch;
|
||||
const size_t y_ne = k * n * batch;
|
||||
const size_t d_ne = m * n * batch;
|
||||
|
||||
const size_t x_sz = sizeof(float) * x_ne;
|
||||
const size_t y_sz = sizeof(float) * y_ne;
|
||||
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
||||
const size_t d_sz = sizeof(float) * d_ne;
|
||||
float * x = (float *) malloc(x_sz);
|
||||
float * y = (float *) malloc(y_sz);
|
||||
void * qx = malloc(qx_sz);
|
||||
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
float * d = (float *) malloc(d_sz);
|
||||
float * d_chk = (float *) malloc(d_sz);
|
||||
|
||||
for (size_t i = 0; i < x_ne; i++) {
|
||||
x[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
vk_pipeline p = ctx->device->pipeline_dequant_mul_mat_mat[quant];
|
||||
|
||||
ggml_vk_quantize_data(x, qx, x_ne, quant);
|
||||
|
||||
for (size_t i = 0; i < y_ne; i++) {
|
||||
y[i] = rand() / (float)RAND_MAX;
|
||||
}
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
|
||||
// Resize buffer
|
||||
if (ctx->prealloc_split_k != nullptr) {
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
|
||||
}
|
||||
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
|
||||
ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
|
||||
|
||||
vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_ctx_begin(ctx, subctx);
|
||||
ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
|
||||
ggml_vk_ctx_end(subctx);
|
||||
}
|
||||
|
||||
auto begin = std::chrono::high_resolution_clock::now();
|
||||
|
||||
ggml_vk_submit(subctx, ctx->fence);
|
||||
VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
|
||||
ctx->device->device.resetFences({ ctx->fence });
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
|
||||
ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
|
||||
|
||||
ggml_init_params iparams = {
|
||||
/*.mem_size =*/ 1024*1024*1024,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_context * ggml_ctx = ggml_init(iparams);
|
||||
|
||||
ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
|
||||
ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
|
||||
ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
|
||||
|
||||
src0_ggml->data = qx;
|
||||
src1_ggml->data = y;
|
||||
tensor_ggml->data = d_chk;
|
||||
|
||||
ctx->disable = true;
|
||||
|
||||
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
|
||||
ggml_build_forward_expand(cgraph, tensor_ggml);
|
||||
|
||||
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
|
||||
|
||||
ctx->disable = false;
|
||||
|
||||
ggml_free(ggml_ctx);
|
||||
|
||||
double avg_err = 0.0;
|
||||
int first_err_n = -1;
|
||||
int first_err_m = -1;
|
||||
int first_err_b = -1;
|
||||
|
||||
for (size_t i = 0; i < m*n*batch; i++) {
|
||||
double err = std::fabs(d[i] - d_chk[i]);
|
||||
avg_err += err;
|
||||
|
||||
if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
|
||||
first_err_b = i / (m * n);
|
||||
first_err_n = (i % (m * n)) / m;
|
||||
first_err_m = (i % (m * n)) % m;
|
||||
}
|
||||
}
|
||||
|
||||
avg_err /= m * n;
|
||||
|
||||
std::cerr << "TEST MMQ " << ggml_type_name(quant) << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
|
||||
|
||||
if (avg_err > 0.1 || std::isnan(avg_err)) {
|
||||
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
||||
std::cerr << "Actual result: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
std::cerr << "Expected result: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
if (split_k > 1) {
|
||||
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
|
||||
ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
|
||||
|
||||
std::cerr << "d_buf0: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
std::cerr << "d_buf1: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
std::cerr << "d_buf2: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
std::cerr << "d_buf3: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
free(split_k_buf);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_destroy_buffer(qx_buf);
|
||||
ggml_vk_destroy_buffer(y_buf);
|
||||
ggml_vk_destroy_buffer(d_buf);
|
||||
|
||||
free(x);
|
||||
free(qx);
|
||||
free(y);
|
||||
free(d);
|
||||
free(d_chk);
|
||||
}
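For reference, the new unit test is invoked from ggml_vk_preallocate_buffers later in this diff; the call is repeated here with the parameters annotated against the signature above:

// ggml_vk_test_dequant_matmul(ctx, m, n, k, batch, num_it, split_k, quant)
ggml_vk_test_dequant_matmul(ctx,
    128,              // m
    512,              // n
    512,              // k
    2,                // batch
    100,              // num_it
    4,                // split_k
    GGML_TYPE_Q4_0);  // quant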
|
||||
#endif
|
||||
|
||||
static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
|
||||
|
@ -4002,18 +4190,8 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
|
|||
return extra;
|
||||
}
|
||||
|
||||
static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
|
||||
GGML_ASSERT(node != nullptr);
|
||||
|
||||
for (int i = graph->n_nodes - 1; i >= 0; i--) {
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (graph->nodes[i]->src[j] == node) {
|
||||
return graph->nodes[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
|
||||
return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
|
||||
}
|
||||
|
||||
static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
|
||||
|
@ -4024,7 +4202,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|
|||
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
|| (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_GPU));
|
||||
|
||||
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
|
||||
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -4055,7 +4233,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|
|||
const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
|
||||
|
||||
int split_k;
|
||||
if (node->op == GGML_OP_MUL_MAT) {
|
||||
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
|
||||
split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
||||
} else {
|
||||
split_k = 1;
|
||||
|
@ -4096,6 +4274,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|
|||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ARGSORT:
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
|
@ -4108,6 +4287,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|
|||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
if (ctx->prealloc_size_qx < qx_sz) {
|
||||
ctx->prealloc_size_qx = qx_sz;
|
||||
}
|
||||
|
@ -4144,17 +4324,21 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|||
ggml_vk_test_transfer(ctx, 8192 * 1000, false);
|
||||
ggml_vk_test_transfer(ctx, 8192 * 1000, true);
|
||||
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
|
||||
ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_F32);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
|
||||
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
|
||||
|
||||
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, GGML_TYPE_Q4_0);
|
||||
|
||||
std::cerr << std::endl;
|
||||
|
||||
const std::vector<size_t> vals {
|
||||
8, 8, 8,
|
||||
|
@ -4242,7 +4426,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
|| (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_GPU);
|
||||
|
||||
if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
|
||||
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -4288,7 +4472,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_ARGSORT:
|
||||
break;
|
||||
default:
|
||||
if (any_on_device) {
|
||||
|
@ -4376,10 +4562,17 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_ROPE:
|
||||
ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
|
@ -4406,7 +4599,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
|
||||
|
||||
if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
|
||||
if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -4432,6 +4625,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_ARGSORT:
|
||||
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
|
||||
break;
|
||||
|
@ -4447,6 +4641,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -5164,6 +5359,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
|||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
struct ggml_tensor * a;
|
||||
struct ggml_tensor * b;
|
||||
|
@ -5237,6 +5433,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
|||
case GGML_OP_CONT:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_ARGSORT:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
(End of ggml-vulkan.cpp; the remaining hunks are from the third changed file, the Vulkan shader generation script.)
@ -207,12 +207,15 @@ mulmat_head = """#version 450
|
|||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#ifndef LOAD_VEC
|
||||
#define LOAD_VEC 1
|
||||
#ifndef LOAD_VEC_A
|
||||
#define LOAD_VEC_A 1
|
||||
#endif
|
||||
#ifndef LOAD_VEC_B
|
||||
#define LOAD_VEC_B 1
|
||||
#endif
|
||||
"""
|
||||
|
||||
mulmat_body = """
|
||||
mulmat_body1 = """
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
|
@ -241,7 +244,7 @@ layout (push_constant) uniform parameter
|
|||
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
layout (constant_id = 3) const uint BK = 16;
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
layout (constant_id = 4) const uint WM = 32;
|
||||
layout (constant_id = 5) const uint WN = 32;
|
||||
layout (constant_id = 6) const uint WMITER = 2;
|
||||
|
@ -278,16 +281,19 @@ void main() {
|
|||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
|
||||
const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
|
||||
const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
|
||||
|
||||
const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
const uint start_k = ik * p.k_split;
|
||||
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
||||
|
||||
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
|
||||
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
|
||||
uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
|
||||
uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
FLOAT_TYPE cache_a[WMITER * TM];
|
||||
|
@ -298,61 +304,77 @@ void main() {
|
|||
}
|
||||
|
||||
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; l < BM; l += loadstride) {
|
||||
#if LOAD_VEC == 8
|
||||
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
||||
#elif LOAD_VEC == 4
|
||||
const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
|
||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
|
||||
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {"""
|
||||
|
||||
mulmat_load_scalar = """
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx][0].x);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx][0].w);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 4] = FLOAT_TYPE(data_a[idx][1].x);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 5] = FLOAT_TYPE(data_a[idx][1].y);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 6] = FLOAT_TYPE(data_a[idx][1].z);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 0] = FLOAT_TYPE(data_a[idx].x);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 1] = FLOAT_TYPE(data_a[idx].y);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 2] = FLOAT_TYPE(data_a[idx].z);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + 3] = FLOAT_TYPE(data_a[idx].w);
|
||||
#else
|
||||
if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
|
||||
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
|
||||
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
||||
} else {
|
||||
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
"""
|
||||
|
||||
mulmat_load_q4_0 = """
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 16; iqs++) {
|
||||
const float d = float(data_a[idx].d);
|
||||
const uint vui = uint(data_a[idx].qs[iqs]);
|
||||
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
|
||||
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 0 ] = FLOAT_TYPE(v.x);
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A + iqs + 16] = FLOAT_TYPE(v.y);
|
||||
}"""
|
||||
|
||||
mulmat_body2 = """
|
||||
}
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride) {
|
||||
#if LOAD_VEC == 8
|
||||
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||
#elif LOAD_VEC == 4
|
||||
const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
|
||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||
#if LOAD_VEC_B == 8
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx][0].w);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 4] = FLOAT_TYPE(data_b[idx][1].x);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 5] = FLOAT_TYPE(data_b[idx][1].y);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 6] = FLOAT_TYPE(data_b[idx][1].z);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 0] = FLOAT_TYPE(data_b[idx].x);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 1] = FLOAT_TYPE(data_b[idx].y);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 2] = FLOAT_TYPE(data_b[idx].z);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B + 3] = FLOAT_TYPE(data_b[idx].w);
|
||||
#else
|
||||
if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
|
||||
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
|
||||
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK / LOAD_VEC;
|
||||
pos_b += BK / LOAD_VEC;
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (uint i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
|
@@ -2085,6 +2107,65 @@ void main() {
}
"""

argsort_src = """
#version 450

#extension GL_EXT_shader_16bit_storage : require

layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;
    bool ascending;
} p;

void swap(uint idx0, uint idx1) {
    int tmp = data_d[idx0];
    data_d[idx0] = data_d[idx1];
    data_d[idx1] = tmp;
}

void main() {
    // bitonic sort
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;

    if (col >= p.ncols) {
        return;
    }

    const uint a_idx = row * p.ncols;
    const uint d_idx = row * p.ncols;

    // initialize indices
    if (col < p.ncols) {
        data_d[col] = col;
    }
    barrier();

    for (uint k = 2; k <= p.ncols; k *= 2) {
        for (uint j = k / 2; j > 0; j /= 2) {
            const uint ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
                        swap(d_idx + col, d_idx + ixj);
                    }
                } else {
                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
                        swap(d_idx + col, d_idx + ixj);
                    }
                }
            }
            barrier();
        }
    }
}
"""

GLSLC = "glslc"

VK_NUM_TYPES = 16
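A CPU reference of the bitonic network used by argsort_src above, useful for checking the shader's compare-and-swap schedule. This is a sketch, not code from the commit; like the in-place shader, it assumes ncols is a power of two:

#include <cstdint>
#include <utility>
#include <vector>

// Argsort one row of ncols floats into idx, following the same (k, j) schedule
// and compare direction as the argsort_f32 shader.
static void bitonic_argsort_row(const float * row, std::vector<int> & idx, uint32_t ncols, bool ascending) {
    idx.resize(ncols);
    for (uint32_t i = 0; i < ncols; ++i) idx[i] = (int)i;

    for (uint32_t k = 2; k <= ncols; k *= 2) {
        for (uint32_t j = k / 2; j > 0; j /= 2) {
            for (uint32_t col = 0; col < ncols; ++col) {   // one shader invocation per col
                const uint32_t ixj = col ^ j;
                if (ixj <= col) continue;
                const bool up = (col & k) == 0;            // ascending half of the bitonic pair
                const bool out_of_order = up
                    ? (ascending ? row[idx[col]] > row[idx[ixj]] : row[idx[col]] < row[idx[ixj]])
                    : (ascending ? row[idx[col]] < row[idx[ixj]] : row[idx[col]] > row[idx[ixj]]);
                if (out_of_order) std::swap(idx[col], idx[ixj]);
            }
        }
    }
}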
@ -2202,15 +2283,19 @@ async def main():
|
|||
vec_type = "vec4"
|
||||
|
||||
stream.clear()
|
||||
stream.extend((mulmat_head, shader_float_type, mulmat_body))
|
||||
stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2))
|
||||
tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
|
||||
|
||||
tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
|
||||
|
||||
tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
|
||||
|
||||
stream.clear()
|
||||
stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
|
||||
tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 32, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
|
||||
|
||||
# Shaders where precision is needed, so no fp16 version
|
||||
|
||||
|
@ -2338,6 +2423,8 @@ async def main():
|
|||
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
|
||||
tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
|
||||
|
||||
# Helper to decorate tasks with semaphore acquisition.
|
||||
async def withSemaphore(sem, task):
|
||||
async with sem:
|
||||
|