Further work towards MoE, disabled for now

0cc4m 2024-05-04 09:27:32 +02:00
parent 1e46fa8dce
commit 3098206b00
3 changed files with 117413 additions and 106886 deletions

File diff suppressed because it is too large


@@ -109,7 +109,6 @@ struct vk_device {
uint32_t descriptor_set_mode;
uint32_t subgroup_size;
bool uma;
-bool buffer_device_address;
bool initialized;
size_t idx;
@@ -129,6 +128,7 @@ struct vk_device {
vk_pipeline pipeline_dequant[VK_NUM_TYPES];
vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
+vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[VK_NUM_TYPES];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -235,6 +235,8 @@ struct vk_mat_vec_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+uint32_t expert_stride_b; uint32_t expert_stride_d;
+uint32_t idx; uint32_t nbi1; uint32_t n_as;
};
struct vk_op_push_constants {
@@ -1025,8 +1027,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-if (ctx->device->buffer_device_address) {
-ctx->device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+/*ctx->device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
@@ -1038,8 +1039,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-}
+ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();*/
if (device->fp16) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
@@ -1133,8 +1133,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-if (ctx->device->buffer_device_address) {
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_len, matmul_id_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
@@ -1223,8 +1222,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-}
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/
} else {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1317,8 +1315,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-if (ctx->device->buffer_device_address) {
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_fp32_len, matmul_id_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
@@ -1407,8 +1404,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-}
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/
}
// mul mat vec
@@ -1424,6 +1420,18 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_K_f32", mul_mat_vec_id_q2_K_f32_len, mul_mat_vec_id_q2_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_K_f32", mul_mat_vec_id_q3_K_f32_len, mul_mat_vec_id_q3_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_K_f32", mul_mat_vec_id_q4_K_f32_len, mul_mat_vec_id_q4_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_K_f32", mul_mat_vec_id_q5_K_f32_len, mul_mat_vec_id_q5_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_K_f32", mul_mat_vec_id_q6_K_f32_len, mul_mat_vec_id_q6_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);*/
// dequant shaders
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1523,15 +1531,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool fp16_storage = false;
bool fp16_compute = false;
-bool bda = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true;
-} else if (strcmp("VK_KHR_buffer_device_address", properties.extensionName) == 0) {
-bda = true;
}
}
@@ -1560,10 +1565,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
-bda = bda && vk12_features.bufferDeviceAddress;
std::string device_name = props2.properties.deviceName.data();
-std::cerr << GGML_VK_NAME << idx << ": " << device_name << " | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << " | buffer device address support: " << bda << std::endl;
+std::cerr << GGML_VK_NAME << idx << ": " << device_name << " | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl;
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
@@ -1735,15 +1739,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
bool fp16_storage = false;
bool fp16_compute = false;
-bool bda = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true;
-} else if (strcmp("VK_KHR_buffer_device_address", properties.extensionName) == 0) {
-bda = true;
}
}
@@ -1792,17 +1793,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
-ctx->device->buffer_device_address = bda && vk12_features.bufferDeviceAddress;
if (!vk11_features.storageBuffer16BitAccess) {
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
throw std::runtime_error("Unsupported device");
}
-if (!ctx->device->buffer_device_address) {
-std::cerr << "ggml_vulkan: Warning: VK_KHR_buffer_device_address extension not supported. Mixture of Experts models unavailable." << std::endl;
-}
device_extensions.push_back("VK_KHR_16bit_storage");
#ifdef GGML_VULKAN_VALIDATE
@@ -1812,9 +1808,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
if (ctx->device->fp16) {
device_extensions.push_back("VK_KHR_shader_float16_int8");
}
-if (ctx->device->buffer_device_address) {
-device_extensions.push_back("VK_KHR_buffer_device_address");
-}
ctx->device->name = ctx->device->properties.deviceName.data();
device_create_info = {
@@ -2729,6 +2722,33 @@ static void ggml_vk_matmul(
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
}
static void ggml_vk_matmul_id(
ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
vk_subbuffer&& ids, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& a, vk_subbuffer&& split_k_buffer,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t expert_stride_b, uint32_t expert_stride_d, uint32_t idx, uint32_t nbi1, uint32_t n_as) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << split_k_buffer.buffer->buffer << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
#endif
ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, d, a }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
return;
}
GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, split_k_buffer, a }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
}
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
return
tensor->nb[0] == ggml_type_size(tensor->type) &&
@@ -3104,6 +3124,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+0, 0, 0, 0, 1
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
@@ -3288,7 +3309,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU);
}
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")" << std::endl;
#endif
@@ -3303,15 +3324,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
}
}
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
#endif
GGML_ASSERT(src0->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
-GGML_ASSERT(ctx->device->buffer_device_address);
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
@@ -3333,13 +3354,14 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
const uint32_t nbi1 = src0->nb[1];
const uint32_t idx = ((uint32_t *) dst->op_params)[0];
-const uint32_t n_as = ((uint32_t *) dst->op_params)[1];
+const uint64_t n_as = ne02;
GGML_ASSERT(n_as <= 8);
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
@@ -3485,6 +3507,165 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
); // NOLINT
}
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
#endif
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
const uint64_t ne02 = src0->ne[2];
const uint64_t ne03 = src0->ne[3];
const uint64_t ne10 = src1->ne[0];
const uint64_t ne11 = src1->ne[1];
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];
GGML_ASSERT(ne11 == 1);
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint64_t ne22 = dst->ne[2];
const uint64_t ne23 = dst->ne[3];
const uint64_t nb22 = dst->nb[2];
const uint64_t nb23 = dst->nb[3];
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
vk_buffer d_Qx;
size_t qx_buf_offset = 0;
vk_buffer d_Qy;
size_t qy_buf_offset = 0;
bool src0_uma = false;
bool src1_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
src0_uma = d_Qx != nullptr;
src1_uma = d_Qy != nullptr;
}
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne11 * ne01;
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_buffer d_D = extra->buffer_gpu.lock();
const uint64_t d_buf_offset = extra->offset;
GGML_ASSERT(d_D != nullptr);
vk_buffer d_X;
uint64_t x_buf_offset = 0;
vk_buffer d_Y;
uint64_t y_buf_offset = 0;
if(!src0_uma) {
d_Qx = extra_src0->buffer_gpu.lock();
qx_buf_offset = extra_src0->offset;
GGML_ASSERT(d_Qx != nullptr);
}
if(!src1_uma) {
d_Qy = extra_src1->buffer_gpu.lock();
qy_buf_offset = extra_src1->offset;
GGML_ASSERT(d_Qy != nullptr);
}
if (qx_needs_dequant) {
d_X = ctx->prealloc_x;
} else {
d_X = d_Qx;
x_buf_offset = qx_buf_offset;
GGML_ASSERT(qx_sz == x_sz);
}
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
GGML_ASSERT(qy_sz == y_sz);
}
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
}
if (y_non_contig) {
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
// Allocate descriptor sets
if (qx_needs_dequant) {
ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
if (qy_needs_dequant) {
ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
}
ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
}
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
}
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
0, 0, 0, 0, 1
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
}
static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// guaranteed to be an integer due to the check in ggml_can_repeat
const uint64_t ne0 = dst->ne[0];
@@ -6052,12 +6233,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
}
break;
case GGML_OP_MUL_MAT:
-case GGML_OP_MUL_MAT_ID:
+// case GGML_OP_MUL_MAT_ID:
{
-if (!ctx->device->buffer_device_address) {
-return false;
-}
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:


@@ -164,47 +164,54 @@ struct block_q6_K
# Dequant functions
shader_float_dequant_func = """
-#define DEQUANT_FUNC vec2 v = vec2(ib, ib); // data_a[ib], data_a[ib + 1]);
+vec2 dequantize(uint ib, uint iqs) {
+return vec2(data_a[ib], data_a[ib + 1]);
+}
"""
shader_q4_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2(vui & 0xF, vui >> 4); \
-v = (v - 8.0f)*d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+}
"""
shader_q4_1_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const float m = float(data_a[ib].m); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2(vui & 0xF, vui >> 4); \
-v = v*d + m;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return vec2(vui & 0xF, vui >> 4) * d + m;
+}
"""
shader_q5_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = (v - 16.0f) * d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+}
"""
shader_q5_1_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const float m = float(data_a[ib].m); \
-const uint uint_qh = data_a[ib].qh; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = v*d + m;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint uint_qh = data_a[ib].qh;
+const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+}
"""
shader_q8_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \
-v = v * d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+return vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+}
"""
# MULMAT
@@ -218,6 +225,7 @@ mulmat_head = """#version 450
#extension GL_EXT_buffer_reference2 : require
#extension GL_EXT_nonuniform_qualifier : require
#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#define EXPERT_COUNT 8
#endif
@@ -233,21 +241,12 @@ mulmat_head = """#version 450
mulmat_body1 = """
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-#ifdef MUL_MAT_ID
-layout (binding = 0) readonly buffer IDS {int data_ids[];};
-#else
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
-layout (buffer_reference) readonly buffer InputA {A_TYPE data[];};
-layout (binding = 3) readonly buffer tensor_input_a {InputA data_a_ptr[EXPERT_COUNT];};
-#define DATA_A data_a.data
-#else
-#define DATA_A data_a
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
layout (push_constant) uniform parameter
@@ -268,12 +267,21 @@ layout (push_constant) uniform parameter
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
-uint expert_stride_b;
+#ifdef MUL_MAT_ID
+uint expert_stride_a;
+uint expert_stride_b0;
+uint expert_stride_b1;
uint expert_stride_d;
-uint idx;
-uint nbi1;
+uint ids_stride;
uint n_as;
+uint nei0;
+uint nei1;
+uint nbi1;
+uint ne11;
+#endif
} p;
layout (constant_id = 1) const uint BM = 64;
@@ -289,9 +297,17 @@ layout (constant_id = 9) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)];
shared FLOAT_TYPE buf_b[BN * (BK+1)];
+#ifdef MUL_MAT_ID
+shared u8vec2 rowids[2048];
+#endif
void main() {
+#ifdef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z / p.n_as;
const uint expert_idx = gl_GlobalInvocationID.z % p.n_as;
+#else
+const uint batch_idx = gl_GlobalInvocationID.z;
+#endif
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -327,15 +343,34 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
-const uint expert_id = data_ids[p.idx + expert_idx * p.nbi1];
-InputA data_a = data_a_ptr[expert_id];
+uint _ne1 = 0;
+for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
+for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+rowids[_ne1] = u8vec2(ii0, ii1);
+_ne1++;
+}
+}
+}
+const u8vec2 id = rowids[ir * BN + ic];
#endif
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
-uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
-uint pos_b = (expert_idx * p.expert_stride_b + batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
+uint pos_a = (
+#ifdef MUL_MAT_ID
+expert_idx * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+uint pos_b = (
+#ifdef MUL_MAT_ID
+id.y * p.expert_stride_b1 +
+(id.x % p.ne11) * p.expert_stride_b0 +
+#endif
+batch_idx * p.batch_stride_b +
+ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
@@ -352,24 +387,24 @@ mulmat_load_scalar = """
#if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-buf_a[buf_idx ] = FLOAT_TYPE(DATA_A[idx][0].x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(DATA_A[idx][0].y);
-buf_a[buf_idx + 2] = FLOAT_TYPE(DATA_A[idx][0].z);
-buf_a[buf_idx + 3] = FLOAT_TYPE(DATA_A[idx][0].w);
-buf_a[buf_idx + 4] = FLOAT_TYPE(DATA_A[idx][1].x);
-buf_a[buf_idx + 5] = FLOAT_TYPE(DATA_A[idx][1].y);
-buf_a[buf_idx + 6] = FLOAT_TYPE(DATA_A[idx][1].z);
-buf_a[buf_idx + 7] = FLOAT_TYPE(DATA_A[idx][1].w);
+buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
+buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-buf_a[buf_idx ] = FLOAT_TYPE(DATA_A[idx].x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(DATA_A[idx].y);
-buf_a[buf_idx + 2] = FLOAT_TYPE(DATA_A[idx].z);
-buf_a[buf_idx + 3] = FLOAT_TYPE(DATA_A[idx].w);
+buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
+buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(DATA_A[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
}
@@ -383,8 +418,8 @@ mulmat_load_q4_0 = """
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
-const float d = float(DATA_A[ib].d);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const float d = float(data_a[ib].d);
+const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -397,9 +432,9 @@ mulmat_load_q4_1 = """
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
-const float d = float(DATA_A[ib].d);
-const float m = float(DATA_A[ib].m);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -412,10 +447,10 @@ mulmat_load_q5_0 = """
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
-const float d = float(DATA_A[ib].d);
-const uint uint_qh = uint(DATA_A[ib].qh[1]) << 16 | DATA_A[ib].qh[0];
+const float d = float(data_a[ib].d);
+const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -428,11 +463,11 @@ mulmat_load_q5_1 = """
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
-const float d = float(DATA_A[ib].d);
-const float m = float(DATA_A[ib].m);
-const uint uint_qh = DATA_A[ib].qh;
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint uint_qh = data_a[ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -445,8 +480,8 @@ mulmat_load_q8_0 = """
const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2;
-const float d = float(DATA_A[ib].d);
-const vec2 v = vec2(int(DATA_A[ib].qs[iqs]), int(DATA_A[ib].qs[iqs + 1])) * d;
+const float d = float(data_a[ib].d);
+const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
@@ -463,9 +498,9 @@ mulmat_load_q2_K = """
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-const uvec2 qs = uvec2(DATA_A[ib].qs[qsi], DATA_A[ib].qs[qsi + 1]);
-const uint scales = DATA_A[ib].scales[scalesi];
-const vec2 d = vec2(DATA_A[ib].d);
+const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+const uint scales = data_a[ib].scales[scalesi];
+const vec2 d = vec2(data_a[ib].d);
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
@@ -488,14 +523,14 @@ mulmat_load_q3_K = """
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
-const int8_t us = int8_t(is < 4 ? (DATA_A[ib].scales[is-0] & 0xF) | (((DATA_A[ib].scales[is+8] >> 0) & 3) << 4) :
-is < 8 ? (DATA_A[ib].scales[is-0] & 0xF) | (((DATA_A[ib].scales[is+4] >> 2) & 3) << 4) :
-is < 12 ? (DATA_A[ib].scales[is-8] >> 4) | (((DATA_A[ib].scales[is+0] >> 4) & 3) << 4) :
-(DATA_A[ib].scales[is-8] >> 4) | (((DATA_A[ib].scales[is-4] >> 6) & 3) << 4));
-const float dl = float(DATA_A[ib].d) * float(us - 32);
-buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((DATA_A[ib].qs[qsi ] >> qsshift) & 3) - (((DATA_A[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
-buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((DATA_A[ib].qs[qsi + 1] >> qsshift) & 3) - (((DATA_A[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
+const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
+is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
+is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
+(data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
+const float dl = float(data_a[ib].d) * float(us - 32);
+buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
+buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
mulmat_load_q4_K = """
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@@ -509,22 +544,22 @@ mulmat_load_q4_K = """
const uint is = 2 * n + b; // 0..7 const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(DATA_A[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; uint8_t sc;
uint8_t mbyte; uint8_t mbyte;
if (is < 4) { if (is < 4) {
sc = uint8_t(DATA_A[ib].scales[is ] & 63); sc = uint8_t(data_a[ib].scales[is ] & 63);
mbyte = uint8_t(DATA_A[ib].scales[is + 4] & 63); mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
} else { } else {
sc = uint8_t((DATA_A[ib].scales[is + 4] & 0xF) | ((DATA_A[ib].scales[is - 4] >> 6) << 4)); sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
mbyte = uint8_t((DATA_A[ib].scales[is + 4] >> 4) | ((DATA_A[ib].scales[is ] >> 6) << 4)); mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
} }
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = loadd.y * mbyte; const float m = loadd.y * mbyte;
buf_a[buf_idx ] = FLOAT_TYPE(d * float((DATA_A[ib].qs[qsi ] >> (b * 4)) & 0xF) - m); buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((DATA_A[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);""" buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
mulmat_load_q5_K = """ mulmat_load_q5_K = """
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@@ -541,22 +576,22 @@ mulmat_load_q5_K = """
const uint8_t hm = uint8_t(1 << (iqs / 16)); const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(DATA_A[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; uint8_t sc;
uint8_t mbyte; uint8_t mbyte;
if (is < 4) { if (is < 4) {
sc = uint8_t(DATA_A[ib].scales[is ] & 63); sc = uint8_t(data_a[ib].scales[is ] & 63);
mbyte = uint8_t(DATA_A[ib].scales[is + 4] & 63); mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
} else { } else {
sc = uint8_t((DATA_A[ib].scales[is + 4] & 0xF) | ((DATA_A[ib].scales[is - 4] >> 6) << 4)); sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
mbyte = uint8_t((DATA_A[ib].scales[is + 4] >> 4) | ((DATA_A[ib].scales[is ] >> 6) << 4)); mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
} }
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = loadd.y * mbyte; const float m = loadd.y * mbyte;
buf_a[buf_idx ] = FLOAT_TYPE(d * (float((DATA_A[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((DATA_A[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m); buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((DATA_A[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((DATA_A[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);""" buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
mulmat_load_q6_K = """ mulmat_load_q6_K = """
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@@ -573,10 +608,10 @@ mulmat_load_q6_K = """
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const float dscale = float(DATA_A[ib].d) * float(DATA_A[ib].scales[is]); const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((DATA_A[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((DATA_A[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((DATA_A[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((DATA_A[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));""" buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
mulmat_body2 = """ mulmat_body2 = """
} }
@@ -643,7 +678,11 @@ mulmat_body2 = """
const uint dr = ir * BM + warp_r * WM; const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN; const uint dc = ic * BN + warp_c * WN;
const uint offsets = expert_idx * p.expert_stride_d + batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; const uint offsets =
#ifdef MUL_MAT_ID
expert_idx * p.expert_stride_d +
#endif
batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -1109,6 +1148,20 @@ mul_mat_vec_head = """#version 450
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif
"""
mul_mat_vec_layout = """
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
uint ncols; uint ncols;
@@ -1124,16 +1177,25 @@ layout (push_constant) uniform parameter
uint batch_stride_a; uint batch_stride_a;
uint batch_stride_b; uint batch_stride_b;
uint batch_stride_d; uint batch_stride_d;
#ifdef MUL_MAT_ID
uint expert_stride_a;
uint expert_stride_b0;
uint expert_stride_b1;
uint expert_stride_d0;
uint expert_stride_d1;
uint ne11;
uint nei0;
uint nbi1;
uint n_as;
#endif
} p; } p;
""" """
mul_mat_vec_body = """ mul_mat_vec_body = """
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE]; shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -1142,6 +1204,10 @@ void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1151,9 +1217,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
@@ -1165,7 +1249,7 @@ void main() {
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index const uint iybs = col - col%QUANT_K; // y block start index
DEQUANT_FUNC vec2 v = dequantize(ib, iqs);
// matrix multiplication // matrix multiplication
tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) + tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
@@ -1181,7 +1265,7 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
@@ -1190,15 +1274,15 @@ void main() {
mul_mat_vec_q2_K_body = """ mul_mat_vec_q2_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1208,9 +1292,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@@ -1268,22 +1370,22 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q3_K_body = """ mul_mat_vec_q3_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1293,9 +1395,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@@ -1346,22 +1466,22 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q4_K_body = """ mul_mat_vec_q4_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1371,9 +1491,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@@ -1473,22 +1611,22 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q5_K_body = """ mul_mat_vec_q5_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1498,9 +1636,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@@ -1596,22 +1752,22 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
mul_mat_vec_q6_K_body = """ mul_mat_vec_q6_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
shared FLOAT_TYPE tmp[32]; shared FLOAT_TYPE tmp[32];
void main() { void main() {
const uint row = gl_WorkGroupID.x; const uint row = gl_WorkGroupID.x;
const uint batch_idx = gl_GlobalInvocationID.y; const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
#endif
const uint i13 = batch_idx / p.ne12; const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12; const uint i12 = batch_idx % p.ne12;
@@ -1621,9 +1777,27 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02; const uint batch_idx_a = i03 * p.ne02 + i02;
const uint a_offset = batch_idx_a * p.batch_stride_a; #ifdef MUL_MAT_ID
const uint b_offset = batch_idx * p.batch_stride_b; const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
const uint d_offset = batch_idx * p.batch_stride_d; #endif
const uint a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a +
#endif
batch_idx_a * p.batch_stride_a;
const uint b_offset =
#ifdef MUL_MAT_ID
(expert_idx0 % p.ne11) * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_b;
const uint d_offset =
#ifdef MUL_MAT_ID
expert_idx0 * p.expert_stride_b0 +
expert_idx1 * p.expert_stride_b1 +
#endif
batch_idx * p.batch_stride_d;
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@@ -1687,7 +1861,7 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[d_offset + row] = D_TYPE(tmp[0]); data_d[d_offset + row] = D_TYPE(tmp[0]);
} }
} }
""" """
@@ -1868,12 +2042,13 @@ layout (push_constant) uniform parameter
float param1; float param2; float param1; float param2;
} p;""" } p;"""
generic_unary_op_funcs = """ generic_unary_op_layout = """
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};"""
generic_unary_op_funcs = """
uint src0_idx(uint idx) { uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
@@ -1901,7 +2076,7 @@ void main() {
} }
""" """
generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_funcs}\n{generic_unary_op_main}" generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
generic_binary_op_head = """#version 450 generic_binary_op_head = """#version 450
@@ -1917,13 +2092,14 @@ layout (push_constant) uniform parameter
float param1; float param2; float param1; float param2;
} p;""" } p;"""
generic_binary_op_funcs = """ generic_binary_op_layout = """
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};"""
generic_binary_op_funcs = """
uint src0_idx(uint idx) { uint src0_idx(uint idx) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
@@ -1962,7 +2138,7 @@ void main() {
} }
""" """
generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_funcs}\n{generic_binary_op_main}" generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
# MUL F32 # MUL F32
mul_body = """ mul_body = """
@@ -2053,7 +2229,7 @@ void main() {
const uint iybs = i00 - i00%QUANT_K; // dst block start index const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
DEQUANT_FUNC vec2 v = dequantize(ib, iqs);
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
@@ -2792,31 +2968,32 @@ async def main():
stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32)) stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
if i == GGML_TYPE_F16: if i == GGML_TYPE_F16:
stream.extend((shader_f16_defines, shader_float_dequant_func, mul_mat_vec_body)) stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q4_0: elif i == GGML_TYPE_Q4_0:
stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body)) stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q4_1: elif i == GGML_TYPE_Q4_1:
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body)) stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q5_0: elif i == GGML_TYPE_Q5_0:
stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body)) stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q5_1: elif i == GGML_TYPE_Q5_1:
stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body)) stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q8_0: elif i == GGML_TYPE_Q8_0:
stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body)) stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body))
elif i == GGML_TYPE_Q2_K: elif i == GGML_TYPE_Q2_K:
stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body)) stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body))
elif i == GGML_TYPE_Q3_K: elif i == GGML_TYPE_Q3_K:
stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body)) stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body))
elif i == GGML_TYPE_Q4_K: elif i == GGML_TYPE_Q4_K:
stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body)) stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body))
elif i == GGML_TYPE_Q5_K: elif i == GGML_TYPE_Q5_K:
stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body)) stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body))
elif i == GGML_TYPE_Q6_K: elif i == GGML_TYPE_Q6_K:
stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body)) stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body))
else: else:
continue continue
tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
# Dequant shaders # Dequant shaders
for i in range(0, VK_NUM_TYPES): for i in range(0, VK_NUM_TYPES):
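As the mul_mat_vec loop above shows, each shader source is now emitted twice: once as-is and once with MUL_MAT_ID defined, producing the additional mul_mat_vec_id_*_f32 variants. A hedged sketch of that pattern (hypothetical helper, not the generator itself):

    def emit_matvec_variants(compile_spv, type_name, source, base_defines):
        # plain single-matrix variant
        compile_spv(f"mul_mat_vec_{type_name}_f32", source, dict(base_defines))
        # indirect (expert-id) variant: same GLSL, extra preprocessor define
        compile_spv(f"mul_mat_vec_id_{type_name}_f32", source,
                    {**base_defines, "MUL_MAT_ID": "1"})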
@@ -2858,20 +3035,20 @@ async def main():
optimization_workaround = False optimization_workaround = False
if i == GGML_TYPE_F32: if i == GGML_TYPE_F32:
stream.extend((shader_f32_defines, generic_binary_op_funcs, get_rows_float_body)) stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
elif i == GGML_TYPE_F16: elif i == GGML_TYPE_F16:
stream.extend((shader_f16_defines, generic_binary_op_funcs, get_rows_float_body)) stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
optimization_workaround = True optimization_workaround = True
elif i == GGML_TYPE_Q4_0: elif i == GGML_TYPE_Q4_0:
stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body)) stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
elif i == GGML_TYPE_Q4_1: elif i == GGML_TYPE_Q4_1:
stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body)) stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
elif i == GGML_TYPE_Q5_0: elif i == GGML_TYPE_Q5_0:
stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body)) stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
elif i == GGML_TYPE_Q5_1: elif i == GGML_TYPE_Q5_1:
stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body)) stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
elif i == GGML_TYPE_Q8_0: elif i == GGML_TYPE_Q8_0:
stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body)) stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
else: else:
continue continue