Further work towards MoE, disabled for now

parent 1e46fa8dce
commit 3098206b00

3 changed files with 117413 additions and 106886 deletions

ggml-vulkan-shaders.hpp: 223199 changes (file diff suppressed because it is too large)
ggml-vulkan.cpp: 253 changes
@@ -109,7 +109,6 @@ struct vk_device {
 uint32_t descriptor_set_mode;
 uint32_t subgroup_size;
 bool uma;
-bool buffer_device_address;

 bool initialized;
 size_t idx;
@@ -129,6 +128,7 @@ struct vk_device {

 vk_pipeline pipeline_dequant[VK_NUM_TYPES];
 vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
+vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[VK_NUM_TYPES];

 vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -235,6 +235,8 @@ struct vk_mat_vec_push_constants {
 uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
 uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+uint32_t expert_stride_b; uint32_t expert_stride_d;
+uint32_t idx; uint32_t nbi1; uint32_t n_as;
 };

 struct vk_op_push_constants {
@@ -1025,8 +1027,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();

-if (ctx->device->buffer_device_address) {
-ctx->device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+/*ctx->device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
@@ -1038,8 +1039,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
 ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-}
+ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();*/

 if (device->fp16) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
@@ -1133,8 +1133,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);

-if (ctx->device->buffer_device_address) {
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_len, matmul_id_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_len, matmul_id_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
@@ -1223,8 +1222,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-}
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/
 } else {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1317,8 +1315,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);

-if (ctx->device->buffer_device_address) {
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_fp32_len, matmul_id_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_aligned_fp32_len, matmul_id_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
@@ -1407,8 +1404,7 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-}
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);*/
 }

 // mul mat vec
@@ -1424,6 +1420,18 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

+/*ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_K_f32", mul_mat_vec_id_q2_K_f32_len, mul_mat_vec_id_q2_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_K_f32", mul_mat_vec_id_q3_K_f32_len, mul_mat_vec_id_q3_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_K_f32", mul_mat_vec_id_q4_K_f32_len, mul_mat_vec_id_q4_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_K_f32", mul_mat_vec_id_q5_K_f32_len, mul_mat_vec_id_q5_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_K_f32", mul_mat_vec_id_q6_K_f32_len, mul_mat_vec_id_q6_K_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);*/
+
 // dequant shaders
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1523,15 +1531,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {

 bool fp16_storage = false;
 bool fp16_compute = false;
-bool bda = false;

 for (auto properties : ext_props) {
 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
 fp16_storage = true;
 } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
 fp16_compute = true;
-} else if (strcmp("VK_KHR_buffer_device_address", properties.extensionName) == 0) {
-bda = true;
 }
 }

@@ -1560,10 +1565,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);

 fp16 = fp16 && vk12_features.shaderFloat16;
-bda = bda && vk12_features.bufferDeviceAddress;

 std::string device_name = props2.properties.deviceName.data();
-std::cerr << GGML_VK_NAME << idx << ": " << device_name << " | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << " | buffer device address support: " << bda << std::endl;
+std::cerr << GGML_VK_NAME << idx << ": " << device_name << " | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl;

 if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
 std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
@@ -1735,15 +1739,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {

 bool fp16_storage = false;
 bool fp16_compute = false;
-bool bda = false;

 for (const auto& properties : ext_props) {
 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
 fp16_storage = true;
 } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
 fp16_compute = true;
-} else if (strcmp("VK_KHR_buffer_device_address", properties.extensionName) == 0) {
-bda = true;
 }
 }

@@ -1792,17 +1793,12 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);

 ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
-ctx->device->buffer_device_address = bda && vk12_features.bufferDeviceAddress;

 if (!vk11_features.storageBuffer16BitAccess) {
 std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
 throw std::runtime_error("Unsupported device");
 }

-if (!ctx->device->buffer_device_address) {
-std::cerr << "ggml_vulkan: Warning: VK_KHR_buffer_device_address extension not supported. Mixture of Experts models unavailable." << std::endl;
-}
-
 device_extensions.push_back("VK_KHR_16bit_storage");

 #ifdef GGML_VULKAN_VALIDATE
@@ -1812,9 +1808,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 if (ctx->device->fp16) {
 device_extensions.push_back("VK_KHR_shader_float16_int8");
 }
-if (ctx->device->buffer_device_address) {
-device_extensions.push_back("VK_KHR_buffer_device_address");
-}
 ctx->device->name = ctx->device->properties.deviceName.data();

 device_create_info = {
@@ -2729,6 +2722,33 @@ static void ggml_vk_matmul(
 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }

+static void ggml_vk_matmul_id(
+ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+vk_subbuffer&& ids, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& a, vk_subbuffer&& split_k_buffer,
+uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
+uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
+uint32_t expert_stride_b, uint32_t expert_stride_d, uint32_t idx, uint32_t nbi1, uint32_t n_as) {
+#ifdef GGML_VULKAN_DEBUG
+std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << split_k_buffer.buffer->buffer << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
+#endif
+ggml_vk_sync_buffers(subctx);
+if (split_k == 1) {
+const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as };
+ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, d, a }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
+return;
+}
+
+GGML_ASSERT(batch_stride_d == m * n);
+
+const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, expert_stride_b, expert_stride_d, idx, nbi1, n_as };
+// Make sure enough workgroups get assigned for split k to work
+ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ids, b, split_k_buffer, a }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+ggml_vk_sync_buffers(subctx);
+const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
+ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
+}
+
 static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
 return
 tensor->nb[0] == ggml_type_size(tensor->type) &&
@@ -3104,6 +3124,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
 (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
 (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
 stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+0, 0, 0, 0, 1
 };
 ggml_vk_sync_buffers(subctx);
 ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
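Note on the hunk above: the regular mat-vec path fills the five MoE-related fields that this commit appends to vk_mat_vec_push_constants with inert values (expert strides 0, idx 0, nbi1 0, n_as 1), so the existing non-MoE shaders are unaffected. A minimal host-side sketch of that layout (illustration only, not code from the commit), assuming the same field order as the struct in this diff:

// Sketch: mirrors vk_mat_vec_push_constants as extended by this commit.
struct mat_vec_push_constants_sketch {
    uint32_t ncols, stride_a, stride_b, stride_d;
    uint32_t ne02, ne12, broadcast2, broadcast3;
    uint32_t batch_stride_a, batch_stride_b, batch_stride_d;
    uint32_t expert_stride_b, expert_stride_d; // new in this commit
    uint32_t idx, nbi1, n_as;                  // new in this commit
};
// Non-MoE dispatch leaves the new fields as 0, 0, 0, 0, 1, matching the
// "0, 0, 0, 0, 1" initializer added to ggml_vk_mul_mat_vec_q_f16 above.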
@@ -3288,7 +3309,7 @@ static bool ggml_vk_can_mul_mat(const ggml_tensor * src0, const ggml_tensor * sr
 ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_TYPE_GPU);
 }

-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #ifdef GGML_VULKAN_DEBUG
 std::cerr << "ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")" << std::endl;
 #endif
|
@ -3303,15 +3324,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||||
#ifdef GGML_VULKAN_DEBUG
|
#ifdef GGML_VULKAN_DEBUG
|
||||||
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||||
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
||||||
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_I32);
|
GGML_ASSERT(src0->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
||||||
GGML_ASSERT(ctx->device->buffer_device_address);
|
|
||||||
|
|
||||||
const uint64_t ne00 = src0->ne[0];
|
const uint64_t ne00 = src0->ne[0];
|
||||||
const uint64_t ne01 = src0->ne[1];
|
const uint64_t ne01 = src0->ne[1];
|
||||||
|
@@ -3333,13 +3354,14 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *

 const uint32_t nbi1 = src0->nb[1];
 const uint32_t idx = ((uint32_t *) dst->op_params)[0];
-const uint32_t n_as = ((uint32_t *) dst->op_params)[1];
+const uint64_t n_as = ne02;

 GGML_ASSERT(n_as <= 8);

 ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
 ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
 ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;

 vk_buffer d_Qx;
 size_t qx_buf_offset = 0;
@@ -3485,6 +3507,165 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
 ); // NOLINT
 }

+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#ifdef GGML_VULKAN_DEBUG
+std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
+#endif
+GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
+GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
+
+const uint64_t ne00 = src0->ne[0];
+const uint64_t ne01 = src0->ne[1];
+const uint64_t ne02 = src0->ne[2];
+const uint64_t ne03 = src0->ne[3];
+
+const uint64_t ne10 = src1->ne[0];
+const uint64_t ne11 = src1->ne[1];
+const uint64_t ne12 = src1->ne[2];
+const uint64_t ne13 = src1->ne[3];
+
+GGML_ASSERT(ne11 == 1);
+
+const uint64_t ne20 = dst->ne[0];
+const uint64_t ne21 = dst->ne[1];
+const uint64_t ne22 = dst->ne[2];
+const uint64_t ne23 = dst->ne[3];
+
+const uint64_t nb22 = dst->nb[2];
+const uint64_t nb23 = dst->nb[3];
+
+const uint64_t r2 = ne12 / ne02;
+const uint64_t r3 = ne13 / ne03;
+
+ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+
+vk_buffer d_Qx;
+size_t qx_buf_offset = 0;
+vk_buffer d_Qy;
+size_t qy_buf_offset = 0;
+
+bool src0_uma = false;
+bool src1_uma = false;
+
+if (ctx->device->uma) {
+ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
+ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
+src0_uma = d_Qx != nullptr;
+src1_uma = d_Qy != nullptr;
+}
+
+const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+
+const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
+
+const bool qx_needs_dequant = x_non_contig;
+const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+
+// Not implemented
+GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
+const uint64_t x_ne = ne01 * ne00;
+const uint64_t y_ne = ne11 * ne10;
+const uint64_t d_ne = ne11 * ne01;
+
+const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+const uint64_t d_sz = sizeof(float) * d_ne;
+
+vk_buffer d_D = extra->buffer_gpu.lock();
+const uint64_t d_buf_offset = extra->offset;
+GGML_ASSERT(d_D != nullptr);
+vk_buffer d_X;
+uint64_t x_buf_offset = 0;
+vk_buffer d_Y;
+uint64_t y_buf_offset = 0;
+if(!src0_uma) {
+d_Qx = extra_src0->buffer_gpu.lock();
+qx_buf_offset = extra_src0->offset;
+GGML_ASSERT(d_Qx != nullptr);
+}
+if(!src1_uma) {
+d_Qy = extra_src1->buffer_gpu.lock();
+qy_buf_offset = extra_src1->offset;
+GGML_ASSERT(d_Qy != nullptr);
+}
+if (qx_needs_dequant) {
+d_X = ctx->prealloc_x;
+} else {
+d_X = d_Qx;
+x_buf_offset = qx_buf_offset;
+GGML_ASSERT(qx_sz == x_sz);
+}
+if (qy_needs_dequant) {
+d_Y = ctx->prealloc_y;
+} else {
+d_Y = d_Qy;
+y_buf_offset = qy_buf_offset;
+GGML_ASSERT(qy_sz == y_sz);
+}
+
+vk_pipeline to_fp16_vk_0 = nullptr;
+vk_pipeline to_fp16_vk_1 = nullptr;
+if (x_non_contig) {
+to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+}
+if (y_non_contig) {
+to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+} else {
+to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+}
+vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
+GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+GGML_ASSERT(dmmv != nullptr);
+
+// Allocate descriptor sets
+if (qx_needs_dequant) {
+ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
+}
+if (qy_needs_dequant) {
+ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+}
+ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
+
+if (x_non_contig) {
+GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
+ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+}
+if (y_non_contig) {
+GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
+ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+}
+
+uint32_t stride_batch_x = ne00*ne01;
+uint32_t stride_batch_y = ne10*ne11;
+
+if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
+stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
+}
+
+if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+}
+
+// compute
+const vk_mat_vec_push_constants pc = {
+(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+0, 0, 0, 0, 1
+};
+ggml_vk_sync_buffers(subctx);
+ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
+}
+
 static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 // guaranteed to be an integer due to the check in ggml_can_repeat
 const uint64_t ne0 = dst->ne[0];
@@ -6052,12 +6233,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
 }
 break;
 case GGML_OP_MUL_MAT:
-case GGML_OP_MUL_MAT_ID:
+// case GGML_OP_MUL_MAT_ID:
 {
-if (!ctx->device->buffer_device_address) {
-return false;
-}
-
 switch (op->src[0]->type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
@@ -164,47 +164,54 @@ struct block_q6_K

 # Dequant functions
 shader_float_dequant_func = """
-#define DEQUANT_FUNC vec2 v = vec2(ib, ib); // data_a[ib], data_a[ib + 1]);
+vec2 dequantize(uint ib, uint iqs) {
+return vec2(data_a[ib], data_a[ib + 1]);
+}
 """

 shader_q4_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2(vui & 0xF, vui >> 4); \
-v = (v - 8.0f)*d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+}
 """

 shader_q4_1_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const float m = float(data_a[ib].m); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2(vui & 0xF, vui >> 4); \
-v = v*d + m;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return vec2(vui & 0xF, vui >> 4) * d + m;
+}
 """

 shader_q5_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = (v - 16.0f) * d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+}
 """

 shader_q5_1_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-const float m = float(data_a[ib].m); \
-const uint uint_qh = data_a[ib].qh; \
-const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
-const uint vui = uint(data_a[ib].qs[iqs]); \
-vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
-v = v*d + m;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint uint_qh = data_a[ib].qh;
+const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+const uint vui = uint(data_a[ib].qs[iqs]);
+return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+}
 """

 shader_q8_0_dequant_func = """
-#define DEQUANT_FUNC const float d = float(data_a[ib].d); \
-vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \
-v = v * d;
+vec2 dequantize(uint ib, uint iqs) {
+const float d = float(data_a[ib].d);
+return vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+}
 """

 # MULMAT
@@ -218,6 +225,7 @@ mulmat_head = """#version 450
 #extension GL_EXT_buffer_reference2 : require
 #extension GL_EXT_nonuniform_qualifier : require
 #extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

 #define EXPERT_COUNT 8
 #endif
|
||||||
mulmat_body1 = """
|
mulmat_body1 = """
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
|
||||||
layout (binding = 0) readonly buffer IDS {int data_ids[];};
|
|
||||||
#else
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
#endif
|
|
||||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
layout (buffer_reference) readonly buffer InputA {A_TYPE data[];};
|
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||||
layout (binding = 3) readonly buffer tensor_input_a {InputA data_a_ptr[EXPERT_COUNT];};
|
|
||||||
|
|
||||||
#define DATA_A data_a.data
|
|
||||||
#else
|
|
||||||
#define DATA_A data_a
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
|
@@ -268,12 +267,21 @@ layout (push_constant) uniform parameter
 uint batch_stride_a;
 uint batch_stride_b;
 uint batch_stride_d;
-uint expert_stride_b;
+#ifdef MUL_MAT_ID
+uint expert_stride_a;
+uint expert_stride_b0;
+uint expert_stride_b1;
 uint expert_stride_d;

-uint idx;
-uint nbi1;
+uint ids_stride;
 uint n_as;
+uint nei0;
+uint nei1;
+uint nbi1;
+uint ne11;
+#endif
 } p;

 layout (constant_id = 1) const uint BM = 64;
@@ -289,9 +297,17 @@ layout (constant_id = 9) const uint WARP = 32;
 shared FLOAT_TYPE buf_a[BM * (BK+1)];
 shared FLOAT_TYPE buf_b[BN * (BK+1)];

+#ifdef MUL_MAT_ID
+shared u8vec2 rowids[2048];
+#endif
+
 void main() {
+#ifdef MUL_MAT_ID
 const uint batch_idx = gl_GlobalInvocationID.z / p.n_as;
 const uint expert_idx = gl_GlobalInvocationID.z % p.n_as;
+#else
+const uint batch_idx = gl_GlobalInvocationID.z;
+#endif

 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@@ -327,15 +343,34 @@ void main() {
 const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;

 #ifdef MUL_MAT_ID
-const uint expert_id = data_ids[p.idx + expert_idx * p.nbi1];
-InputA data_a = data_a_ptr[expert_id];
+uint _ne1 = 0;
+for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
+for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+rowids[_ne1] = u8vec2(ii0, ii1);
+_ne1++;
+}
+}
+}
+
+const u8vec2 id = rowids[ir * BN + ic];
 #endif

 const uint start_k = ik * p.k_split;
 const uint end_k = min(p.K, (ik + 1) * p.k_split);

-uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
-uint pos_b = (expert_idx * p.expert_stride_b + batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
+uint pos_a = (
+#ifdef MUL_MAT_ID
+expert_idx * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+uint pos_b = (
+#ifdef MUL_MAT_ID
+id.y * p.expert_stride_b1 +
+(id.x % p.ne11) * p.expert_stride_b0 +
+#endif
+batch_idx * p.batch_stride_b +
+ic * BN * p.stride_b + start_k) / LOAD_VEC_B;

 float sums[WMITER * TM * WNITER * TN];
 FLOAT_TYPE cache_a[WMITER * TM];
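Note on the hunk above: in the new MUL_MAT_ID path each workgroup scans the nei1 x nei0 ids matrix and records, for the expert it handles, which (row, column) positions of the routing matrix select it; those pairs are then used to index the B tensor through expert_stride_b0/expert_stride_b1. A minimal CPU-side sketch of that gather, written only to illustrate the loop in the shader (the helper name and the std::vector return are not part of the commit):

#include <cstdint>
#include <utility>
#include <vector>

// Sketch: ids is an nei1 x nei0 matrix (row stride nbi1) of expert indices.
// Returns every (ii0, ii1) whose entry equals expert_idx, the same pairs the
// shader stores as u8vec2(ii0, ii1) in the shared rowids array.
static std::vector<std::pair<uint32_t, uint32_t>> gather_rowids(
        const int32_t * ids, uint32_t nei0, uint32_t nei1, uint32_t nbi1, int32_t expert_idx) {
    std::vector<std::pair<uint32_t, uint32_t>> rowids;
    for (uint32_t ii1 = 0; ii1 < nei1; ii1++) {
        for (uint32_t ii0 = 0; ii0 < nei0; ii0++) {
            if (ids[ii1 * nbi1 + ii0] == expert_idx) {
                rowids.push_back({ii0, ii1});
            }
        }
    }
    return rowids;
}

In the shader this list is indexed with ir * BN + ic to pick the pair for the current tile; id.y then selects the expert_stride_b1 slice of B and id.x % ne11 the expert_stride_b0 slice, as in the pos_b computation above.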
@@ -352,24 +387,24 @@ mulmat_load_scalar = """
 #if LOAD_VEC_A == 8
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-buf_a[buf_idx ] = FLOAT_TYPE(DATA_A[idx][0].x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(DATA_A[idx][0].y);
-buf_a[buf_idx + 2] = FLOAT_TYPE(DATA_A[idx][0].z);
-buf_a[buf_idx + 3] = FLOAT_TYPE(DATA_A[idx][0].w);
-buf_a[buf_idx + 4] = FLOAT_TYPE(DATA_A[idx][1].x);
-buf_a[buf_idx + 5] = FLOAT_TYPE(DATA_A[idx][1].y);
-buf_a[buf_idx + 6] = FLOAT_TYPE(DATA_A[idx][1].z);
-buf_a[buf_idx + 7] = FLOAT_TYPE(DATA_A[idx][1].w);
+buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
+buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
 #elif LOAD_VEC_A == 4
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-buf_a[buf_idx ] = FLOAT_TYPE(DATA_A[idx].x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(DATA_A[idx].y);
-buf_a[buf_idx + 2] = FLOAT_TYPE(DATA_A[idx].z);
-buf_a[buf_idx + 3] = FLOAT_TYPE(DATA_A[idx].w);
+buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
+buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
 #else
 if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(DATA_A[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
 } else {
 buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
 }
@ -383,8 +418,8 @@ mulmat_load_q4_0 = """
 const uint ib = idx / 16;
 const uint iqs = idx & 0xF;
 
-const float d = float(DATA_A[ib].d);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const float d = float(data_a[ib].d);
+const uint vui = uint(data_a[ib].qs[iqs]);
 const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
 
 buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@ -397,9 +432,9 @@ mulmat_load_q4_1 = """
 const uint ib = idx / 16;
 const uint iqs = idx & 0xF;
 
-const float d = float(DATA_A[ib].d);
-const float m = float(DATA_A[ib].m);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint vui = uint(data_a[ib].qs[iqs]);
 const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
 
 buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@ -412,10 +447,10 @@ mulmat_load_q5_0 = """
 const uint ib = idx / 16;
 const uint iqs = idx & 0xF;
 
-const float d = float(DATA_A[ib].d);
-const uint uint_qh = uint(DATA_A[ib].qh[1]) << 16 | DATA_A[ib].qh[0];
+const float d = float(data_a[ib].d);
+const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
 const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const uint vui = uint(data_a[ib].qs[iqs]);
 const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
 
 buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@ -428,11 +463,11 @@ mulmat_load_q5_1 = """
 const uint ib = idx / 16;
 const uint iqs = idx & 0xF;
 
-const float d = float(DATA_A[ib].d);
-const float m = float(DATA_A[ib].m);
-const uint uint_qh = DATA_A[ib].qh;
+const float d = float(data_a[ib].d);
+const float m = float(data_a[ib].m);
+const uint uint_qh = data_a[ib].qh;
 const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-const uint vui = uint(DATA_A[ib].qs[iqs]);
+const uint vui = uint(data_a[ib].qs[iqs]);
 const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
 
 buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@ -445,8 +480,8 @@ mulmat_load_q8_0 = """
 const uint ib = idx / 16;
 const uint iqs = (idx & 0xF) * 2;
 
-const float d = float(DATA_A[ib].d);
-const vec2 v = vec2(int(DATA_A[ib].qs[iqs]), int(DATA_A[ib].qs[iqs + 1])) * d;
+const float d = float(data_a[ib].d);
+const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
 
 buf_a[buf_idx ] = FLOAT_TYPE(v.x);
 buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);"""
@ -463,9 +498,9 @@ mulmat_load_q2_K = """
 const uint scalesi = iqs / 8; // 0..15
 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
 
-const uvec2 qs = uvec2(DATA_A[ib].qs[qsi], DATA_A[ib].qs[qsi + 1]);
-const uint scales = DATA_A[ib].scales[scalesi];
-const vec2 d = vec2(DATA_A[ib].d);
+const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+const uint scales = data_a[ib].scales[scalesi];
+const vec2 d = vec2(data_a[ib].d);
 
 const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
 
@ -488,14 +523,14 @@ mulmat_load_q3_K = """
 const uint qsshift = halfsplit * 2; // 0,2,4,6
 const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
 
-const int8_t us = int8_t(is < 4 ? (DATA_A[ib].scales[is-0] & 0xF) | (((DATA_A[ib].scales[is+8] >> 0) & 3) << 4) :
-is < 8 ? (DATA_A[ib].scales[is-0] & 0xF) | (((DATA_A[ib].scales[is+4] >> 2) & 3) << 4) :
-is < 12 ? (DATA_A[ib].scales[is-8] >> 4) | (((DATA_A[ib].scales[is+0] >> 4) & 3) << 4) :
-(DATA_A[ib].scales[is-8] >> 4) | (((DATA_A[ib].scales[is-4] >> 6) & 3) << 4));
-const float dl = float(DATA_A[ib].d) * float(us - 32);
+const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
+is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
+is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
+(data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
+const float dl = float(data_a[ib].d) * float(us - 32);
 
-buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((DATA_A[ib].qs[qsi ] >> qsshift) & 3) - (((DATA_A[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
-buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((DATA_A[ib].qs[qsi + 1] >> qsshift) & 3) - (((DATA_A[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
+buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
+buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));"""
 
 mulmat_load_q4_K = """
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@ -509,22 +544,22 @@ mulmat_load_q4_K = """
 const uint is = 2 * n + b; // 0..7
 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
 
-const vec2 loadd = vec2(DATA_A[ib].d);
+const vec2 loadd = vec2(data_a[ib].d);
 
 uint8_t sc;
 uint8_t mbyte;
 if (is < 4) {
-sc = uint8_t(DATA_A[ib].scales[is ] & 63);
-mbyte = uint8_t(DATA_A[ib].scales[is + 4] & 63);
+sc = uint8_t(data_a[ib].scales[is ] & 63);
+mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
 } else {
-sc = uint8_t((DATA_A[ib].scales[is + 4] & 0xF) | ((DATA_A[ib].scales[is - 4] >> 6) << 4));
-mbyte = uint8_t((DATA_A[ib].scales[is + 4] >> 4) | ((DATA_A[ib].scales[is ] >> 6) << 4));
+sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
 }
 const float d = loadd.x * sc;
 const float m = loadd.y * mbyte;
 
-buf_a[buf_idx ] = FLOAT_TYPE(d * float((DATA_A[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
-buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((DATA_A[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
+buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
+buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);"""
 
 mulmat_load_q5_K = """
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@ -541,22 +576,22 @@ mulmat_load_q5_K = """
 
 const uint8_t hm = uint8_t(1 << (iqs / 16));
 
-const vec2 loadd = vec2(DATA_A[ib].d);
+const vec2 loadd = vec2(data_a[ib].d);
 
 uint8_t sc;
 uint8_t mbyte;
 if (is < 4) {
-sc = uint8_t(DATA_A[ib].scales[is ] & 63);
-mbyte = uint8_t(DATA_A[ib].scales[is + 4] & 63);
+sc = uint8_t(data_a[ib].scales[is ] & 63);
+mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
 } else {
-sc = uint8_t((DATA_A[ib].scales[is + 4] & 0xF) | ((DATA_A[ib].scales[is - 4] >> 6) << 4));
-mbyte = uint8_t((DATA_A[ib].scales[is + 4] >> 4) | ((DATA_A[ib].scales[is ] >> 6) << 4));
+sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
 }
 const float d = loadd.x * sc;
 const float m = loadd.y * mbyte;
 
-buf_a[buf_idx ] = FLOAT_TYPE(d * (float((DATA_A[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((DATA_A[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
-buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((DATA_A[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((DATA_A[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
+buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
+buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);"""
 
 mulmat_load_q6_K = """
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
@ -573,10 +608,10 @@ mulmat_load_q6_K = """
 const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
 const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
 
-const float dscale = float(DATA_A[ib].d) * float(DATA_A[ib].scales[is]);
+const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
 
-buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((DATA_A[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((DATA_A[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
-buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((DATA_A[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((DATA_A[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
+buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
+buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));"""
 
 mulmat_body2 = """
 }
@ -643,7 +678,11 @@ mulmat_body2 = """
 const uint dr = ir * BM + warp_r * WM;
 const uint dc = ic * BN + warp_c * WN;
 
-const uint offsets = expert_idx * p.expert_stride_d + batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+const uint offsets =
+#ifdef MUL_MAT_ID
+expert_idx * p.expert_stride_d +
+#endif
+batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 
 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@ -1109,6 +1148,20 @@ mul_mat_vec_head = """#version 450
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
 
+#ifdef MUL_MAT_ID
+#define EXPERT_COUNT 8
+#endif
+"""
+
+
+mul_mat_vec_layout = """
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+#endif
+
 layout (push_constant) uniform parameter
 {
 uint ncols;
@ -1124,16 +1177,25 @@ layout (push_constant) uniform parameter
 uint batch_stride_a;
 uint batch_stride_b;
 uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+uint expert_stride_a;
+uint expert_stride_b0;
+uint expert_stride_b1;
+uint expert_stride_d0;
+uint expert_stride_d1;
+
+uint ne11;
+uint nei0;
+uint nbi1;
+uint n_as;
+#endif
 } p;
 """
 
 mul_mat_vec_body = """
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
@ -1142,6 +1204,10 @@ void main() {
 const uint row = gl_WorkGroupID.x;
 const uint tid = gl_LocalInvocationID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1151,9 +1217,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
@ -1165,7 +1249,7 @@ void main() {
 const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
 const uint iybs = col - col%QUANT_K; // y block start index
 
-DEQUANT_FUNC
+vec2 v = dequantize(ib, iqs);
 
 // matrix multiplication
 tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
@ -1181,7 +1265,7 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
@ -1190,15 +1274,15 @@ void main() {
 mul_mat_vec_q2_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 shared FLOAT_TYPE tmp[32];
 
 void main() {
 const uint row = gl_WorkGroupID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1208,9 +1292,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint num_blocks_per_row = p.ncols / QUANT_K;
 const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@ -1268,22 +1370,22 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
 mul_mat_vec_q3_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 shared FLOAT_TYPE tmp[32];
 
 void main() {
 const uint row = gl_WorkGroupID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1293,9 +1395,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint num_blocks_per_row = p.ncols / QUANT_K;
 const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@ -1346,22 +1466,22 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
 mul_mat_vec_q4_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 shared FLOAT_TYPE tmp[32];
 
 void main() {
 const uint row = gl_WorkGroupID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1371,9 +1491,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint num_blocks_per_row = p.ncols / QUANT_K;
 const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@ -1473,22 +1611,22 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
 mul_mat_vec_q5_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 shared FLOAT_TYPE tmp[32];
 
 void main() {
 const uint row = gl_WorkGroupID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1498,9 +1636,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint num_blocks_per_row = p.ncols / QUANT_K;
 const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@ -1596,22 +1752,22 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
 mul_mat_vec_q6_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
 shared FLOAT_TYPE tmp[32];
 
 void main() {
 const uint row = gl_WorkGroupID.x;
 const uint batch_idx = gl_GlobalInvocationID.y;
+#ifdef MUL_MAT_ID
+const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0;
+const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0;
+#endif
 
 const uint i13 = batch_idx / p.ne12;
 const uint i12 = batch_idx % p.ne12;
@ -1621,9 +1777,27 @@ void main() {
 
 const uint batch_idx_a = i03 * p.ne02 + i02;
 
-const uint a_offset = batch_idx_a * p.batch_stride_a;
-const uint b_offset = batch_idx * p.batch_stride_b;
-const uint d_offset = batch_idx * p.batch_stride_d;
+#ifdef MUL_MAT_ID
+const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0];
+#endif
+
+const uint a_offset =
+#ifdef MUL_MAT_ID
+expert_id * p.expert_stride_a +
+#endif
+batch_idx_a * p.batch_stride_a;
+const uint b_offset =
+#ifdef MUL_MAT_ID
+(expert_idx0 % p.ne11) * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_b;
+const uint d_offset =
+#ifdef MUL_MAT_ID
+expert_idx0 * p.expert_stride_b0 +
+expert_idx1 * p.expert_stride_b1 +
+#endif
+batch_idx * p.batch_stride_d;
 
 const uint num_blocks_per_row = p.ncols / QUANT_K;
 const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
@ -1687,7 +1861,7 @@ void main() {
 barrier();
 }
 if (tid == 0) {
-dst[d_offset + row] = D_TYPE(tmp[0]);
+data_d[d_offset + row] = D_TYPE(tmp[0]);
 }
 }
 """
@ -1868,12 +2042,13 @@ layout (push_constant) uniform parameter
 float param1; float param2;
 } p;"""
 
-generic_unary_op_funcs = """
+generic_unary_op_layout = """
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};"""
 
+generic_unary_op_funcs = """
 uint src0_idx(uint idx) {
 const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
 const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
@ -1901,7 +2076,7 @@ void main() {
 }
 """
 
-generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
+generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}"
 
 generic_binary_op_head = """#version 450
 
@ -1917,13 +2092,14 @@ layout (push_constant) uniform parameter
 float param1; float param2;
 } p;"""
 
-generic_binary_op_funcs = """
+generic_binary_op_layout = """
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};"""
 
+generic_binary_op_funcs = """
 uint src0_idx(uint idx) {
 const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
 const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
@ -1962,7 +2138,7 @@ void main() {
 }
 """
 
-generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
+generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
 
 # MUL F32
 mul_body = """
@ -2053,7 +2229,7 @@ void main() {
 const uint iybs = i00 - i00%QUANT_K; // dst block start index
 const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
-DEQUANT_FUNC
+vec2 v = dequantize(ib, iqs);
 
 data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
 data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
@ -2792,31 +2968,32 @@ async def main():
 stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
 
 if i == GGML_TYPE_F16:
-stream.extend((shader_f16_defines, shader_float_dequant_func, mul_mat_vec_body))
+stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q4_0:
-stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body))
+stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q4_1:
-stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body))
+stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q5_0:
-stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body))
+stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q5_1:
-stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body))
+stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q8_0:
-stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body))
+stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body))
 elif i == GGML_TYPE_Q2_K:
-stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
+stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body))
 elif i == GGML_TYPE_Q3_K:
-stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
+stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body))
 elif i == GGML_TYPE_Q4_K:
-stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
+stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body))
 elif i == GGML_TYPE_Q5_K:
-stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body))
+stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body))
 elif i == GGML_TYPE_Q6_K:
-stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
+stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body))
 else:
 continue
 
 tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
+tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
 
 # Dequant shaders
 for i in range(0, VK_NUM_TYPES):
@ -2858,20 +3035,20 @@ async def main():
 optimization_workaround = False
 
 if i == GGML_TYPE_F32:
-stream.extend((shader_f32_defines, generic_binary_op_funcs, get_rows_float_body))
+stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
 elif i == GGML_TYPE_F16:
-stream.extend((shader_f16_defines, generic_binary_op_funcs, get_rows_float_body))
+stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body))
 optimization_workaround = True
 elif i == GGML_TYPE_Q4_0:
-stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
+stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body))
 elif i == GGML_TYPE_Q4_1:
-stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
+stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body))
 elif i == GGML_TYPE_Q5_0:
-stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
+stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body))
 elif i == GGML_TYPE_Q5_1:
-stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
+stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body))
 elif i == GGML_TYPE_Q8_0:
-stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
+stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body))
 else:
 continue
 