vulkan: support GGML_OP_SUM
This commit is contained in:
parent
5c1d8a946f
commit
abf4c2ef74
1 changed files with 19 additions and 0 deletions
|
@ -5276,6 +5276,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_argsort_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sum_rows_f32;
|
||||
|
@ -5554,6 +5555,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
elements = { nr, 1, 1 };
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
// We use GGML_OP_SUM_ROWS with 1 row.
|
||||
elements = { 1, 1, 1 };
|
||||
break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
const uint32_t num_groups = dst->op_params[0];
|
||||
|
@ -6136,6 +6141,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
|
@ -7029,6 +7038,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -7080,6 +7090,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -7200,6 +7211,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_ARGSORT:
|
||||
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SUM:
|
||||
ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
@ -7314,6 +7329,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -8248,6 +8264,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
@ -8819,6 +8836,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
|
||||
} else if (tensor->op == GGML_OP_ARGSORT) {
|
||||
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_SUM) {
|
||||
tensor_clone = ggml_sum(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
||||
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_IM2COL) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue