diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0c685c09c..21b7a1f9b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5276,6 +5276,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_argsort_f32; } return nullptr; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; @@ -5554,6 +5555,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; case GGML_OP_GROUP_NORM: { const uint32_t num_groups = dst->op_params[0]; @@ -6136,6 +6141,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } @@ -7029,6 +7038,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -7080,6 +7090,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -7200,6 +7211,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ARGSORT: ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); @@ -7314,6 +7329,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_TRANSPOSE: case GGML_OP_NONE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -8248,6 +8264,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -8819,6 +8836,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_ARGSORT) { tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_IM2COL) {