vulkan: support GGML_OP_SUM

Author: Rémy O
Date:   2025-02-08 10:47:05 +01:00
Parent: 5c1d8a946f
Commit: abf4c2ef74


@@ -5276,6 +5276,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argsort_f32;
         }
         return nullptr;
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_sum_rows_f32;
@@ -5554,6 +5555,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_SUM:
+        // We use GGML_OP_SUM_ROWS with 1 row.
+        elements = { 1, 1, 1 };
+        break;
     case GGML_OP_GROUP_NORM:
         {
             const uint32_t num_groups = dst->op_params[0];
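
Note on the hunk above: the whole implementation trick is in that comment. GGML_OP_SUM reuses the existing sum_rows compute pipeline and dispatches a single workgroup (elements = { 1, 1, 1 }), so the shader reduces the flattened tensor as if it were one long row. A minimal CPU-side sketch of the equivalence this relies on (hypothetical illustration assuming contiguous F32 data; not code from this commit):

// Sketch: a full sum over contiguous F32 data equals a row-sum over the
// same data viewed as a single row of ggml_nelements(src0) values.
// Function names below are hypothetical.
#include <cstdint>

static float sum_rows_f32(const float * row, uint32_t ne0) {
    float acc = 0.0f;
    for (uint32_t i = 0; i < ne0; ++i) {
        acc += row[i];  // the same reduction the sum_rows shader performs per row
    }
    return acc;
}

static float sum_f32(const float * data, uint32_t nelements) {
    // GGML_OP_SUM == GGML_OP_SUM_ROWS with one row spanning all elements
    return sum_rows_f32(data, nelements);
}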
@@ -6136,6 +6141,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
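
The new ggml_vk_sum wrapper differs from ggml_vk_sum_rows only in its first push constant: the per-row element count becomes ggml_nelements(src0) instead of src0->ne[0], which is what lets the shader reduce the whole tensor in one pass. For context, the generic push-constant block both wrappers fill looks roughly like this (field names as found in ggml-vulkan.cpp around this revision; treat the exact layout as an assumption):

struct vk_op_push_constants {
    uint32_t KX;      // elements per row: src0->ne[0] for SUM_ROWS, ggml_nelements(src0) for SUM
    uint32_t KY;      // unused by these ops (0)
    float    param1;  // unused by these ops (0.0f)
    float    param2;  // unused by these ops (0.0f)
};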
@@ -7029,6 +7038,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7080,6 +7090,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7200,6 +7211,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGSORT:
         ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SUM:
+        ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
@@ -7314,6 +7329,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_TRANSPOSE:
     case GGML_OP_NONE:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8248,6 +8264,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8819,6 +8836,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ARGSORT) {
         tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_SUM) {
+        tensor_clone = ggml_sum(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_IM2COL) {
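
The hunk above extends the GGML_VULKAN_CHECK_RESULTS path, which rebuilds each node with the CPU reference ops and compares outputs, so GGML_OP_SUM is now validated against ggml_sum. A hedged sketch of what that reference computes, using the public ggml API (the header split and exact helper locations vary across revisions):

// Standalone sketch: sum all elements of an F32 tensor on the CPU backend.
#include "ggml.h"
#include "ggml-cpu.h"   // ggml_graph_compute_with_ctx (location varies by revision)
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4x3 F32 tensor filled with 2.0f -> expected sum 24.0f
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    float * ad = (float *) a->data;
    for (int64_t i = 0; i < ggml_nelements(a); ++i) {
        ad[i] = 2.0f;
    }

    // ggml_sum reduces all elements into a single F32 scalar tensor
    struct ggml_tensor * s = ggml_sum(ctx, a);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, s);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("sum = %f\n", ((const float *) s->data)[0]);   // expect 24.000000
    ggml_free(ctx);
    return 0;
}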