From 8932135fdb8ee05a9e4c44a64137a28c35bf05bc Mon Sep 17 00:00:00 2001 From: hongruichen Date: Thu, 11 Jul 2024 00:07:00 +0800 Subject: [PATCH] add sqrt and mul ops --- ggml/src/ggml-qnn/backend-ops.cpp | 137 +++++++++++++++--------------- 1 file changed, 70 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-qnn/backend-ops.cpp b/ggml/src/ggml-qnn/backend-ops.cpp index 5871a7b6e..ca48f79bb 100644 --- a/ggml/src/ggml-qnn/backend-ops.cpp +++ b/ggml/src/ggml-qnn/backend-ops.cpp @@ -8,6 +8,18 @@ #include "tensor.hpp" #include "utils.hpp" +#ifndef NDEBUG +#define CHECK_PARAMS(ctx, src0, src1, dst) \ + do { \ + if (!qnn_is_valid_params((ctx), (src0), (src1), (dst))) { \ + return; \ + } \ + } while (0) + +#else +#define CHECK_PARAMS(ctx, src0, src1, dst) +#endif + namespace { void print_ggml_tensor(const ggml_tensor *tensor) { @@ -144,29 +156,29 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c } constexpr const char *kGgmlOpToQnnOp[] = { - nullptr, // GGML_OP_NONE - nullptr, // GGML_OP_DUP - QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD - nullptr, // GGML_OP_ADD1 - nullptr, // GGML_OP_ACC - nullptr, // GGML_OP_SUB - nullptr, // GGML_OP_MUL - nullptr, // GGML_OP_DIV - nullptr, // GGML_OP_SQR - nullptr, // GGML_OP_SQRT - nullptr, // GGML_OP_LOG - nullptr, // GGML_OP_SUM - nullptr, // GGML_OP_SUM_ROWS - nullptr, // GGML_OP_MEAN - nullptr, // GGML_OP_ARGMAX - nullptr, // GGML_OP_REPEAT - nullptr, // GGML_OP_REPEAT_BACK - nullptr, // GGML_OP_CONCAT - nullptr, // GGML_OP_SILU_BACK - nullptr, // GGML_OP_NORM - nullptr, // GGML_OP_RMS_NORM - nullptr, // GGML_OP_RMS_NORM_BACK - nullptr, // GGML_OP_GROUP_NORM + nullptr, // GGML_OP_NONE + nullptr, // GGML_OP_DUP + QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD + nullptr, // GGML_OP_ADD1 + nullptr, // GGML_OP_ACC + nullptr, // GGML_OP_SUB + QNN_OP_ELEMENT_WISE_MULTIPLY, // GGML_OP_MUL + nullptr, // GGML_OP_DIV + nullptr, // GGML_OP_SQR + QNN_OP_ELEMENT_WISE_SQUARE_ROOT, // GGML_OP_SQRT + nullptr, // GGML_OP_LOG + nullptr, // GGML_OP_SUM + nullptr, // GGML_OP_SUM_ROWS + nullptr, // GGML_OP_MEAN + nullptr, // GGML_OP_ARGMAX + nullptr, // GGML_OP_REPEAT + nullptr, // GGML_OP_REPEAT_BACK + nullptr, // GGML_OP_CONCAT + nullptr, // GGML_OP_SILU_BACK + nullptr, // GGML_OP_NORM + nullptr, // GGML_OP_RMS_NORM + nullptr, // GGML_OP_RMS_NORM_BACK + nullptr, // GGML_OP_GROUP_NORM QNN_OP_MAT_MUL, // GGML_OP_MUL_MAT nullptr, // GGML_OP_MUL_MAT_ID @@ -236,6 +248,8 @@ void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, ggml_tensor *dst) { static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP"); + CHECK_PARAMS(ctx, src0, src1, dst); + qnn::qnn_perf perf(ggml_op_name(_GgmlOp)); perf.start(); @@ -255,24 +269,16 @@ void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, } // namespace -#ifndef NDEBUG -#define CHECK_PARAMS(ctx, src0, src1, dst) \ - do { \ - if (!qnn_is_valid_params((ctx), (src0), (src1), (dst))) { \ - return; \ - } \ - } while (0) - -#else -#define CHECK_PARAMS(ctx, src0, src1, dst) -#endif - static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) { - CHECK_PARAMS(ctx, src0, src1, dst); qnn_binary_op_impl(ctx, src0, src1, dst); } +static void ggml_qnn_mul(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst) { + qnn_binary_op_impl(ctx, src0, src1, dst); +} + /* * ggml_qnn_mul_mat was re-added as a standalone function because * the following comments came from https://github.com/ggerganov/llama.cpp/pull/1632 @@ -286,7 +292,6 @@ static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, */ static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) { - CHECK_PARAMS(ctx, src0, src1, dst); qnn_binary_op_impl(ctx, src0, src1, dst); } @@ -329,6 +334,11 @@ static void ggml_qnn_leaky_relu(ggml_backend_qnn_context *ctx, const ggml_tensor static void ggml_qnn_sqr(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) {} +static void ggml_qnn_sqrt(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst) { + qnn_binary_op_impl(ctx, src0, src1, dst); +} + static void ggml_qnn_norm(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) {} @@ -392,38 +402,31 @@ static void ggml_qnn_argsort(ggml_backend_qnn_context *ctx, const ggml_tensor *s GGML_ASSERT(ggml_is_contiguous(src0)); } -static void ggml_qnn_nop(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst) { - (void)src0; - (void)src1; - (void)dst; -} - qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() { static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[] = { - nullptr, // GGML_OP_NONE - nullptr, // GGML_OP_DUP - ggml_qnn_add, // GGML_OP_ADD - nullptr, // GGML_OP_ADD1 - nullptr, // GGML_OP_ACC - nullptr, // GGML_OP_SUB - nullptr, // GGML_OP_MUL - nullptr, // GGML_OP_DIV - nullptr, // GGML_OP_SQR - nullptr, // GGML_OP_SQRT - nullptr, // GGML_OP_LOG - nullptr, // GGML_OP_SUM - nullptr, // GGML_OP_SUM_ROWS - nullptr, // GGML_OP_MEAN - nullptr, // GGML_OP_ARGMAX - nullptr, // GGML_OP_REPEAT - nullptr, // GGML_OP_REPEAT_BACK - nullptr, // GGML_OP_CONCAT - nullptr, // GGML_OP_SILU_BACK - nullptr, // GGML_OP_NORM - nullptr, // GGML_OP_RMS_NORM - nullptr, // GGML_OP_RMS_NORM_BACK - nullptr, // GGML_OP_GROUP_NORM + nullptr, // GGML_OP_NONE + nullptr, // GGML_OP_DUP + ggml_qnn_add, // GGML_OP_ADD + nullptr, // GGML_OP_ADD1 + nullptr, // GGML_OP_ACC + nullptr, // GGML_OP_SUB + ggml_qnn_mul, // GGML_OP_MUL + nullptr, // GGML_OP_DIV + nullptr, // GGML_OP_SQR + ggml_qnn_sqrt, // GGML_OP_SQRT + nullptr, // GGML_OP_LOG + nullptr, // GGML_OP_SUM + nullptr, // GGML_OP_SUM_ROWS + nullptr, // GGML_OP_MEAN + nullptr, // GGML_OP_ARGMAX + nullptr, // GGML_OP_REPEAT + nullptr, // GGML_OP_REPEAT_BACK + nullptr, // GGML_OP_CONCAT + nullptr, // GGML_OP_SILU_BACK + nullptr, // GGML_OP_NORM + nullptr, // GGML_OP_RMS_NORM + nullptr, // GGML_OP_RMS_NORM_BACK + nullptr, // GGML_OP_GROUP_NORM ggml_qnn_mul_mat, // GGML_OP_MUL_MAT nullptr, // GGML_OP_MUL_MAT_ID