diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ef067356e..53f9168d2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3929,13 +3929,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); - - if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32; - } - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32_f16; + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (src1 == nullptr || src1->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_f32; + } + else if (src1->type == GGML_TYPE_F16) { + return ctx->device->pipeline_soft_max_f32_f16; + } } return nullptr; case GGML_OP_ROPE: