diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7b13cc10b..b3f6177dd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -218,6 +218,7 @@ struct vk_device_struct { vk_pipeline pipeline_tanh_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_argsort_f32; @@ -1498,7 +1499,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -3933,10 +3936,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32; + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32_f16; + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; } return nullptr; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index a83b9a405..f9727679e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -152,30 +152,21 @@ void main() { // instantiate the soft_max function for several different // dimensions, to allow loop unrolling uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; - switch (num_blocks) { - case 1: - soft_max(1); - break; - case 2: - soft_max(2); - break; - case 3: - soft_max(3); - break; - case 4: - soft_max(4); - break; - case 5: - case 6: - case 7: - case 8: - soft_max(8); - break; - case 16: - soft_max(16); - break; - default: + if (num_blocks > 32) { soft_max(num_blocks); - break; + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); } }