Rework matmul pipeline selection

This commit is contained in:
0cc4m 2024-02-29 22:34:09 +01:00
parent 6314096db9
commit c3eba7c1c9
3 changed files with 1940 additions and 223 deletions

File diff suppressed because it is too large Load diff

View file

@ -89,6 +89,13 @@ typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
};
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_device {
vk::PhysicalDevice physical_device;
vk::PhysicalDeviceProperties properties;
@ -107,15 +114,12 @@ struct vk_device {
bool initialized;
size_t idx;
vk_pipeline pipeline_matmul_f32_l, pipeline_matmul_f32_m, pipeline_matmul_f32_s;
vk_pipeline pipeline_matmul_f32_aligned_l, pipeline_matmul_f32_aligned_m, pipeline_matmul_f32_aligned_s;
vk_pipeline pipeline_matmul_f16_l, pipeline_matmul_f16_m, pipeline_matmul_f16_s;
vk_pipeline pipeline_matmul_f16_aligned_l, pipeline_matmul_f16_aligned_m, pipeline_matmul_f16_aligned_s;
vk_pipeline pipeline_matmul_f16_f32_l, pipeline_matmul_f16_f32_m, pipeline_matmul_f16_f32_s;
vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
vk_matmul_pipeline pipeline_matmul_f32;
vk_matmul_pipeline pipeline_matmul_f16;
vk_matmul_pipeline pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
vk_pipeline pipeline_dequant[VK_NUM_TYPES];
vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
@ -957,6 +961,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
@ -967,52 +973,67 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
uint32_t m_align = 64;
uint32_t s_align = 32;
ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
if (device->fp16) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
} else {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], "matmul_q4_0_f32_aligned", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
}
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32", mul_mat_vec_f16_f32_len, mul_mat_vec_f16_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
@ -1461,18 +1482,30 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
return ctx->device->pipeline_dequant[type];
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_mat(ggml_backend_vk_context * ctx, ggml_type type) {
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_get_dequantize_mul_mat_mat()" << std::endl;
std::cerr << "ggml_vk_get_mul_mat_mat_pipeline()" << std::endl;
#endif
switch (type) {
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return ctx->device->pipeline_matmul_f32;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return ctx->device->pipeline_matmul_f16_f32;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return ctx->device->pipeline_matmul_f16;
}
GGML_ASSERT(src1_type == GGML_TYPE_F32);
switch (src0_type) {
case GGML_TYPE_Q4_0:
break;
default:
return nullptr;
}
return ctx->device->pipeline_dequant_mul_mat_mat[type];
return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
@ -2157,176 +2190,63 @@ static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * su
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl;
#endif
if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " = 4" << std::endl;
#endif
return 4;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " = 1" << std::endl;
#endif
return 1;
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, int m, int n) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
#endif
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
if (m <= 32 || n <= 32) {
return ctx->device->pipeline_matmul_f32_aligned_s->align;
return aligned ? mmp->a_s : mmp->s;
}
if (ctx->device->subgroup_size == 64 || m <= 64 || n <= 64) {
return ctx->device->pipeline_matmul_f32_aligned_m->align;
}
return ctx->device->pipeline_matmul_f32_aligned_l->align;
return aligned ? mmp->a_m : mmp->m;
GGML_UNUSED(ctx);
}
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
if (bit16_x && bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_aligned_s : ctx->device->pipeline_matmul_f16_s;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_aligned_m : ctx->device->pipeline_matmul_f16_m;
}
if (bit16_x && !bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_s : ctx->device->pipeline_matmul_f16_f32_s;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_m : ctx->device->pipeline_matmul_f16_f32_m;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
return aligned ? mmp->a_m : mmp->m;
GGML_UNUSED(ctx);
}
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f32_aligned_s : ctx->device->pipeline_matmul_f32_s;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f32_aligned_m : ctx->device->pipeline_matmul_f32_m;
static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
return aligned ? mmp->a_s : mmp->s;
GGML_UNUSED(ctx);
}
static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
if (bit16_x && bit16_y) {
return aligned ? ctx->device->pipeline_matmul_f16_aligned_m : ctx->device->pipeline_matmul_f16_m;
}
if (bit16_x && !bit16_y) {
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_m : ctx->device->pipeline_matmul_f16_f32_m;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
}
return aligned ? ctx->device->pipeline_matmul_f32_aligned_m : ctx->device->pipeline_matmul_f32_m;
}
static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
if (bit16_x && bit16_y) {
return aligned ? ctx->device->pipeline_matmul_f16_aligned_s : ctx->device->pipeline_matmul_f16_s;
}
if (bit16_x && !bit16_y) {
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_s : ctx->device->pipeline_matmul_f16_f32_s;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
}
return aligned ? ctx->device->pipeline_matmul_f32_aligned_s : ctx->device->pipeline_matmul_f32_s;
}
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
std::cerr << "ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")" << std::endl;
#endif
switch (ctx->device->vendor_id) {
case VK_VENDOR_ID_AMD:
return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
case VK_VENDOR_ID_APPLE:
return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
case VK_VENDOR_ID_INTEL:
return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
}
if (bit16_x && bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_aligned_s : ctx->device->pipeline_matmul_f16_s;
}
if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_aligned_m : ctx->device->pipeline_matmul_f16_m;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " L" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_aligned_l : ctx->device->pipeline_matmul_f16_l;
}
if (bit16_x && !bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_s : ctx->device->pipeline_matmul_f16_f32_s;
}
if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_m : ctx->device->pipeline_matmul_f16_f32_m;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " L" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f16_f32_aligned_l : ctx->device->pipeline_matmul_f16_f32_l;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
}
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f32_aligned_s : ctx->device->pipeline_matmul_f32_s;
return aligned ? mmp->a_s : mmp->s;
}
if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f32_aligned_m : ctx->device->pipeline_matmul_f32_m;
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " L" << std::endl;
std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
#endif
return aligned ? ctx->device->pipeline_matmul_f32_aligned_l : ctx->device->pipeline_matmul_f32_l;
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align;
}
static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
@ -2444,11 +2364,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
vk_pipeline dmmm = ggml_vk_get_dequantize_mul_mat_mat(ctx, src0->type);
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
const bool qx_needs_dequant = (dmmm == nullptr && src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || x_non_contig;
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
if (mmp == nullptr) {
// Fall back to dequant + f16 mulmat
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
}
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
@ -2456,16 +2381,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, ne01, ne11));
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
const bool aligned = ne10 == kpad;
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
vk_pipeline pipeline = dmmm != nullptr ? dmmm : ggml_vk_guess_matmul_pipeline(ctx, true, !y_f32_kernel, ne01, ne11, aligned);
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = dmmm != nullptr ? ggml_nbytes(src0) : sizeof(ggml_fp16_t) * x_ne;
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
const uint64_t d_sz = sizeof(float) * d_ne;
@ -3595,39 +3520,39 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
std::string shname;
if (shader_size == 0) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_aligned_s;
p = ctx->device->pipeline_matmul_f32->a_s;
shname = "F32_ALIGNED_S";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_aligned_s;
p = ctx->device->pipeline_matmul_f16_f32->a_s;
shname = "F16_F32_ALIGNED_S";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_aligned_s;
p = ctx->device->pipeline_matmul_f16->a_s;
shname = "F16_ALIGNED_S";
} else {
GGML_ASSERT(false);
}
} else if (shader_size == 1) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_aligned_m;
p = ctx->device->pipeline_matmul_f32->a_m;
shname = "F32_ALIGNED_M";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_aligned_m;
p = ctx->device->pipeline_matmul_f16_f32->a_m;
shname = "F16_F32_ALIGNED_M";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_aligned_m;
p = ctx->device->pipeline_matmul_f16->a_m;
shname = "F16_ALIGNED_M";
} else {
GGML_ASSERT(false);
}
} else if (shader_size == 2) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_aligned_l;
p = ctx->device->pipeline_matmul_f32->a_l;
shname = "F32_ALIGNED_L";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_aligned_l;
p = ctx->device->pipeline_matmul_f16_f32->a_l;
shname = "F16_F32_ALIGNED_L";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_aligned_l;
p = ctx->device->pipeline_matmul_f16->a_l;
shname = "F16_ALIGNED_L";
} else {
GGML_ASSERT(false);
@ -3641,35 +3566,35 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
if (k != kpad) {
if (shader_size == 0) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_s;
p = ctx->device->pipeline_matmul_f32->s;
shname = "F32_S";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_s;
p = ctx->device->pipeline_matmul_f16_f32->s;
shname = "F16_F32_S";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_s;
p = ctx->device->pipeline_matmul_f16->s;
shname = "F16_S";
}
} else if (shader_size == 1) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_m;
p = ctx->device->pipeline_matmul_f32->m;
shname = "F32_M";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_m;
p = ctx->device->pipeline_matmul_f16_f32->m;
shname = "F16_F32_M";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_m;
p = ctx->device->pipeline_matmul_f16->m;
shname = "F16_M";
}
} else if (shader_size == 2) {
if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f32_l;
p = ctx->device->pipeline_matmul_f32->l;
shname = "F32_L";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_f32_l;
p = ctx->device->pipeline_matmul_f16_f32->l;
shname = "F16_F32_L";
} else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
p = ctx->device->pipeline_matmul_f16_l;
p = ctx->device->pipeline_matmul_f16->l;
shname = "F16_L";
}
}
@ -4181,7 +4106,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
free(x_chk);
}
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, ggml_type quant) {
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
#endif
@ -4189,6 +4114,38 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
const size_t y_ne = k * n * batch;
const size_t d_ne = m * n * batch;
vk_pipeline p;
std::string shname;
if (shader_size == 0) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
} else if (shader_size == 1) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
} else if (shader_size == 2) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
if (k != kpad) {
if (shader_size == 0) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
shname = std::string(ggml_type_name(quant)) + "_S";
} else if (shader_size == 1) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
shname = std::string(ggml_type_name(quant)) + "_M";
} else if (shader_size == 2) {
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
shname = std::string(ggml_type_name(quant)) + "_L";
} else {
GGML_ASSERT(0);
}
}
const size_t x_sz = sizeof(float) * x_ne;
const size_t y_sz = sizeof(float) * y_ne;
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
@ -4206,8 +4163,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
x[i] = rand() / (float)RAND_MAX;
}
vk_pipeline p = ctx->device->pipeline_dequant_mul_mat_mat[quant];
ggml_vk_quantize_data(x, qx, x_ne, quant);
for (size_t i = 0; i < y_ne; i++) {
@ -4294,18 +4249,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
avg_err /= m * n;
std::cerr << "TEST MMQ " << ggml_type_name(quant) << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
if (avg_err > 0.1 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
std::cerr << "Actual result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
std::cerr << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 25, first_err_b);
std::cerr << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 35, first_err_b);
std::cerr << "Expected result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
@ -4499,8 +4449,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
// ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
@ -5179,6 +5133,8 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
return true;
}
return false;
UNUSED(buffer);
}
GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@ -6022,7 +5978,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, *(float *)tensor->op_params);
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, nullptr, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else {
tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
}

View file

@ -2302,6 +2302,7 @@ async def main():
stream.clear()
stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2))
tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
# Shaders where precision is needed, so no fp16 version