From 5da73f8085e9b3276ec7dba2c30c8ad775b7bcdd Mon Sep 17 00:00:00 2001
From: hongruichen
Date: Sat, 27 Jul 2024 12:52:59 +0800
Subject: [PATCH] refactoring: move forward and supports_op into ops file

---
 ggml/src/ggml-qnn.cpp             |  55 +-------------
 ggml/src/ggml-qnn/backend-ops.cpp | 115 ++++++++++++++++++++++++------
 ggml/src/ggml-qnn/backend-ops.hpp |  13 +---
 3 files changed, 97 insertions(+), 86 deletions(-)

diff --git a/ggml/src/ggml-qnn.cpp b/ggml/src/ggml-qnn.cpp
index 6472d3e15..22b57b175 100644
--- a/ggml/src/ggml-qnn.cpp
+++ b/ggml/src/ggml-qnn.cpp
@@ -114,23 +114,7 @@ struct ggml_backend_qnn_buffer_type_context {
 //
 // =================================================================================================
 static bool ggml_qnn_compute_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
-    size_t unary_op_idx = tensor->op;
-    if (tensor->op == GGML_OP_UNARY) {
-        unary_op_idx = qnn::kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
-    }
-
-    auto unary_op = qnn::ggml_qnn_unary_op_array()[unary_op_idx];
-    if (unary_op) {
-        return unary_op(ctx, tensor->src[0], tensor);
-    }
-
-    auto binary_op = qnn::ggml_qnn_binary_op_array()[tensor->op];
-    if (binary_op) {
-        return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
-    QNN_LOG_WARN("unsupported op %d", tensor->op);
-    return false;
+    return qnn::ggml_qnn_forward(ctx, tensor);
 }
 
 static const char *ggml_backend_qnn_buffer_get_name(ggml_backend_buffer_t buffer) {
@@ -288,42 +272,7 @@ GGML_CALL static ggml_status ggml_backend_qnn_graph_compute(ggml_backend_t backe
 
 GGML_CALL static bool ggml_backend_qnn_supports_op(ggml_backend_t backend, const ggml_tensor *op) {
     GGML_UNUSED(backend);
-
-    if (op->op == GGML_OP_UNARY) {
-        if (!qnn::ggml_qnn_unary_op_array()[qnn::kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
-            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
-            return false;
-        }
-
-        if (!op->src[0]) {
-            QNN_LOG_DEBUG("src0 is nullptr");
-            return false;
-        }
-    } else if (op->op != GGML_OP_NONE) {
-        if (!qnn::ggml_qnn_unary_op_array()[op->op] && !qnn::ggml_qnn_binary_op_array()[op->op]) {
-            QNN_LOG_DEBUG("unsupported op %d", op->op);
-            return false;
-        }
-
-        if (!op->src[0] || !op->src[1]) {
-            QNN_LOG_DEBUG("src0 or src1 is nullptr");
-            return false;
-        }
-    }
-
-    switch (op->type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q4_0:
-            break;
-        default:
-            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
-            return false;
-    }
-
-    return true;
+    return qnn::ggml_qnn_supports_op(op);
 }
 
 GGML_CALL static bool ggml_backend_qnn_offload_op(ggml_backend_t backend, const ggml_tensor *op) {
diff --git a/ggml/src/ggml-qnn/backend-ops.cpp b/ggml/src/ggml-qnn/backend-ops.cpp
index 1f8b75e5e..20a4178fd 100644
--- a/ggml/src/ggml-qnn/backend-ops.cpp
+++ b/ggml/src/ggml-qnn/backend-ops.cpp
@@ -56,6 +56,15 @@ bool qnn_is_valid_params(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
 
 namespace {
 
+typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
+typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
+                                     ggml_tensor *dst);
+
+typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
+typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
+
+constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
+
 void print_ggml_tensor(const ggml_tensor *tensor) {
     QNN_LOG_DEBUG("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi)\n",
                   tensor->name, tensor->type, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2],
@@ -106,8 +115,8 @@ qnn::ggml_qnn_graph *get_qnn_graph_from_cache(ggml_backend_qnn_context *ctx, siz
     GGML_ASSERT(op < (GGML_OP_COUNT + GGML_UNARY_OP_COUNT));
 
     auto &graph_cache = ctx->qnn_graph_cache;
-    const auto *op_name = op < qnn::kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op))
-                                                      : ggml_unary_op_name(ggml_unary_op(op - qnn::kGgmlUnaryOpStart));
+    const auto *op_name =
+        op < kGgmlUnaryOpStart ? ggml_op_name(ggml_op(op)) : ggml_unary_op_name(ggml_unary_op(op - kGgmlUnaryOpStart));
     auto graph_key = get_graph_key<_InputSize, _OutputSize>(op_name, inputs, outputs);
     auto it = graph_cache.find(graph_key);
     qnn::ggml_qnn_graph *graph_ptr = nullptr;
@@ -237,7 +246,7 @@ constexpr const char *kGgmlOpToQnnOp[] = {
 static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
               "GGML_OP_COUNT does not match the size of the kGgmlOpToQnnOp table");
 
-static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart] != nullptr,
+static_assert(kGgmlOpToQnnOp[GGML_UNARY_OP_GELU + kGgmlUnaryOpStart] != nullptr,
               "GGML_UNARY_OP_GELU does not correspond to QNN_OP_GELU");
 
 template <size_t _GgmlOp>
@@ -281,10 +290,8 @@ bool qnn_unary_op_impl(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_ten
     return succeed;
 }
 
-} // namespace
-
-qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
-    static constexpr const qnn::ggml_qnn_unary_op_t kQnnOpsTable[] = {
+ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array() {
+    static constexpr const ggml_qnn_unary_op_t kQnnOpsTable[] = {
         nullptr, // GGML_OP_NONE
         nullptr, // GGML_OP_DUP
         nullptr, // GGML_OP_ADD
@@ -369,19 +376,19 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
         nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK
 
         // ggml_unary_op
-        nullptr,                                                        // GGML_UNARY_OP_ABS
-        nullptr,                                                        // GGML_UNARY_OP_SGN
-        nullptr,                                                        // GGML_UNARY_OP_NEG
-        nullptr,                                                        // GGML_UNARY_OP_STEP
-        nullptr,                                                        // GGML_UNARY_OP_TANH
-        nullptr,                                                        // GGML_UNARY_OP_ELU
-        nullptr,                                                        // GGML_UNARY_OP_RELU
-        nullptr,                                                        // GGML_UNARY_OP_SIGMOID
-        qnn_unary_op_impl<GGML_UNARY_OP_GELU + qnn::kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
-        nullptr,                                                        // GGML_UNARY_OP_GELU_QUICK
-        nullptr,                                                        // GGML_UNARY_OP_SILU
-        nullptr,                                                        // GGML_UNARY_OP_HARDSWISH
-        nullptr,                                                        // GGML_UNARY_OP_HARDSIGMOID
+        nullptr,                                                   // GGML_UNARY_OP_ABS
+        nullptr,                                                   // GGML_UNARY_OP_SGN
+        nullptr,                                                   // GGML_UNARY_OP_NEG
+        nullptr,                                                   // GGML_UNARY_OP_STEP
+        nullptr,                                                   // GGML_UNARY_OP_TANH
+        nullptr,                                                   // GGML_UNARY_OP_ELU
+        nullptr,                                                   // GGML_UNARY_OP_RELU
+        nullptr,                                                   // GGML_UNARY_OP_SIGMOID
+        qnn_unary_op_impl<GGML_UNARY_OP_GELU + kGgmlUnaryOpStart>, // GGML_UNARY_OP_GELU
+        nullptr,                                                   // GGML_UNARY_OP_GELU_QUICK
+        nullptr,                                                   // GGML_UNARY_OP_SILU
+        nullptr,                                                   // GGML_UNARY_OP_HARDSWISH
+        nullptr,                                                   // GGML_UNARY_OP_HARDSIGMOID
     };
 
     static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == (GGML_OP_COUNT + GGML_UNARY_OP_COUNT),
@@ -389,8 +396,8 @@ qnn::ggml_qnn_unary_op_array_t qnn::ggml_qnn_unary_op_array() {
     return kQnnOpsTable;
 }
 
-qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
-    static constexpr const qnn::ggml_qnn_binary_op_t kQnnOpsTable[] = {
+ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array() {
+    static constexpr const ggml_qnn_binary_op_t kQnnOpsTable[] = {
         nullptr,                         // GGML_OP_NONE
         nullptr,                         // GGML_OP_DUP
         qnn_binary_op_impl<GGML_OP_ADD>, // GGML_OP_ADD
@@ -479,3 +486,67 @@ qnn::ggml_qnn_binary_op_array_t qnn::ggml_qnn_binary_op_array() {
                   "GGML_OP_COUNT does not match the size of the ops table");
     return kQnnOpsTable;
 }
+
+} // namespace
+
+namespace qnn {
+
+bool ggml_qnn_supports_op(const ggml_tensor *op) {
+    if (op->op == GGML_OP_UNARY) {
+        if (!ggml_qnn_unary_op_array()[kGgmlUnaryOpStart + ggml_get_unary_op(op)]) {
+            QNN_LOG_DEBUG("unsupported unary op %d", ggml_get_unary_op(op));
+            return false;
+        }
+
+        if (!op->src[0]) {
+            QNN_LOG_DEBUG("src0 is nullptr");
+            return false;
+        }
+    } else if (op->op != GGML_OP_NONE) {
+        if (!ggml_qnn_unary_op_array()[op->op] && !ggml_qnn_binary_op_array()[op->op]) {
+            QNN_LOG_DEBUG("unsupported op %d", op->op);
+            return false;
+        }
+
+        if (!op->src[0] || !op->src[1]) {
+            QNN_LOG_DEBUG("src0 or src1 is nullptr");
+            return false;
+        }
+    }
+
+    switch (op->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            QNN_LOG_DEBUG("unsupported src0 type %d", op->src[0]->type);
+            return false;
+    }
+
+    return true;
+}
+
+bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor) {
+    size_t unary_op_idx = tensor->op;
+    if (tensor->op == GGML_OP_UNARY) {
+        unary_op_idx = kGgmlUnaryOpStart + ggml_get_unary_op(tensor);
+    }
+
+    auto unary_op = ggml_qnn_unary_op_array()[unary_op_idx];
+    if (unary_op) {
+        return unary_op(ctx, tensor->src[0], tensor);
+    }
+
+    auto binary_op = ggml_qnn_binary_op_array()[tensor->op];
+    if (binary_op) {
+        return binary_op(ctx, tensor->src[0], tensor->src[1], tensor);
+    }
+
+    QNN_LOG_WARN("unsupported op %d", tensor->op);
+    return false;
+}
+
+} // namespace qnn
diff --git a/ggml/src/ggml-qnn/backend-ops.hpp b/ggml/src/ggml-qnn/backend-ops.hpp
index 614bcf651..ed4ce994f 100644
--- a/ggml/src/ggml-qnn/backend-ops.hpp
+++ b/ggml/src/ggml-qnn/backend-ops.hpp
@@ -6,16 +6,7 @@
 
 namespace qnn {
 
-typedef bool (*ggml_qnn_unary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src, ggml_tensor *dst);
-typedef bool (*ggml_qnn_binary_op_t)(ggml_backend_qnn_context *ctx, ggml_tensor *src0, ggml_tensor *src1,
-                                     ggml_tensor *dst);
-
-typedef const ggml_qnn_unary_op_t (&ggml_qnn_unary_op_array_t)[GGML_OP_COUNT + GGML_UNARY_OP_COUNT];
-typedef const ggml_qnn_binary_op_t (&ggml_qnn_binary_op_array_t)[GGML_OP_COUNT];
-
-constexpr const size_t kGgmlUnaryOpStart = GGML_OP_COUNT;
-
-ggml_qnn_unary_op_array_t ggml_qnn_unary_op_array();
-ggml_qnn_binary_op_array_t ggml_qnn_binary_op_array();
+bool ggml_qnn_supports_op(const ggml_tensor *op);
+bool ggml_qnn_forward(ggml_backend_qnn_context *ctx, struct ggml_tensor *tensor);
 
 } // namespace qnn
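
For reference, below is a minimal sketch of how a caller sits on top of the two entry points this patch exports from the qnn namespace. The wrapper name run_node is hypothetical and not part of the patch; the signatures of qnn::ggml_qnn_supports_op and qnn::ggml_qnn_forward are taken from backend-ops.hpp above.

    #include "ggml-qnn/backend-ops.hpp"

    // Hypothetical caller (run_node is an illustrative name, not in the patch):
    // mirrors how ggml-qnn.cpp now delegates both the capability check and the
    // dispatch to backend-ops.cpp, where the unary/binary op tables are private.
    static bool run_node(ggml_backend_qnn_context *ctx, ggml_tensor *node) {
        if (!qnn::ggml_qnn_supports_op(node)) {
            return false; // caller would fall back to another backend
        }
        return qnn::ggml_qnn_forward(ctx, node);
    }

Since ggml_qnn_supports_op now consults the same op tables that ggml_qnn_forward dispatches through, the support check and the execution path can no longer drift apart, which is the point of moving both into the ops file.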