add sqrt and mul ops

This commit is contained in:
hongruichen 2024-07-11 00:07:00 +08:00
parent 7ea28a6fac
commit 8932135fdb

View file

@ -8,6 +8,18 @@
#include "tensor.hpp" #include "tensor.hpp"
#include "utils.hpp" #include "utils.hpp"
#ifndef NDEBUG
#define CHECK_PARAMS(ctx, src0, src1, dst) \
do { \
if (!qnn_is_valid_params((ctx), (src0), (src1), (dst))) { \
return; \
} \
} while (0)
#else
#define CHECK_PARAMS(ctx, src0, src1, dst)
#endif
namespace { namespace {
void print_ggml_tensor(const ggml_tensor *tensor) { void print_ggml_tensor(const ggml_tensor *tensor) {
@ -144,29 +156,29 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c
} }
constexpr const char *kGgmlOpToQnnOp[] = { constexpr const char *kGgmlOpToQnnOp[] = {
nullptr, // GGML_OP_NONE nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP nullptr, // GGML_OP_DUP
QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1 nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB nullptr, // GGML_OP_SUB
nullptr, // GGML_OP_MUL QNN_OP_ELEMENT_WISE_MULTIPLY, // GGML_OP_MUL
nullptr, // GGML_OP_DIV nullptr, // GGML_OP_DIV
nullptr, // GGML_OP_SQR nullptr, // GGML_OP_SQR
nullptr, // GGML_OP_SQRT QNN_OP_ELEMENT_WISE_SQUARE_ROOT, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG nullptr, // GGML_OP_LOG
nullptr, // GGML_OP_SUM nullptr, // GGML_OP_SUM
nullptr, // GGML_OP_SUM_ROWS nullptr, // GGML_OP_SUM_ROWS
nullptr, // GGML_OP_MEAN nullptr, // GGML_OP_MEAN
nullptr, // GGML_OP_ARGMAX nullptr, // GGML_OP_ARGMAX
nullptr, // GGML_OP_REPEAT nullptr, // GGML_OP_REPEAT
nullptr, // GGML_OP_REPEAT_BACK nullptr, // GGML_OP_REPEAT_BACK
nullptr, // GGML_OP_CONCAT nullptr, // GGML_OP_CONCAT
nullptr, // GGML_OP_SILU_BACK nullptr, // GGML_OP_SILU_BACK
nullptr, // GGML_OP_NORM nullptr, // GGML_OP_NORM
nullptr, // GGML_OP_RMS_NORM nullptr, // GGML_OP_RMS_NORM
nullptr, // GGML_OP_RMS_NORM_BACK nullptr, // GGML_OP_RMS_NORM_BACK
nullptr, // GGML_OP_GROUP_NORM nullptr, // GGML_OP_GROUP_NORM
QNN_OP_MAT_MUL, // GGML_OP_MUL_MAT QNN_OP_MAT_MUL, // GGML_OP_MUL_MAT
nullptr, // GGML_OP_MUL_MAT_ID nullptr, // GGML_OP_MUL_MAT_ID
@ -236,6 +248,8 @@ void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
ggml_tensor *dst) { ggml_tensor *dst) {
static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP"); static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP");
CHECK_PARAMS(ctx, src0, src1, dst);
qnn::qnn_perf perf(ggml_op_name(_GgmlOp)); qnn::qnn_perf perf(ggml_op_name(_GgmlOp));
perf.start(); perf.start();
@ -255,24 +269,16 @@ void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
} // namespace } // namespace
#ifndef NDEBUG
#define CHECK_PARAMS(ctx, src0, src1, dst) \
do { \
if (!qnn_is_valid_params((ctx), (src0), (src1), (dst))) { \
return; \
} \
} while (0)
#else
#define CHECK_PARAMS(ctx, src0, src1, dst)
#endif
static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) { ggml_tensor *dst) {
CHECK_PARAMS(ctx, src0, src1, dst);
qnn_binary_op_impl<GGML_OP_ADD>(ctx, src0, src1, dst); qnn_binary_op_impl<GGML_OP_ADD>(ctx, src0, src1, dst);
} }
static void ggml_qnn_mul(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) {
qnn_binary_op_impl<GGML_OP_MUL>(ctx, src0, src1, dst);
}
/* /*
* ggml_qnn_mul_mat was re-added as a standalone function because * ggml_qnn_mul_mat was re-added as a standalone function because
* the following comments came from https://github.com/ggerganov/llama.cpp/pull/1632 * the following comments came from https://github.com/ggerganov/llama.cpp/pull/1632
@ -286,7 +292,6 @@ static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0,
*/ */
static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) { ggml_tensor *dst) {
CHECK_PARAMS(ctx, src0, src1, dst);
qnn_binary_op_impl<GGML_OP_MUL_MAT>(ctx, src0, src1, dst); qnn_binary_op_impl<GGML_OP_MUL_MAT>(ctx, src0, src1, dst);
} }
@ -329,6 +334,11 @@ static void ggml_qnn_leaky_relu(ggml_backend_qnn_context *ctx, const ggml_tensor
static void ggml_qnn_sqr(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, static void ggml_qnn_sqr(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) {} ggml_tensor *dst) {}
static void ggml_qnn_sqrt(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) {
qnn_binary_op_impl<GGML_OP_SQRT>(ctx, src0, src1, dst);
}
static void ggml_qnn_norm(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, static void ggml_qnn_norm(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) {} ggml_tensor *dst) {}
@ -392,38 +402,31 @@ static void ggml_qnn_argsort(ggml_backend_qnn_context *ctx, const ggml_tensor *s
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
} }
static void ggml_qnn_nop(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst) {
(void)src0;
(void)src1;
(void)dst;
}
qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() { qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() {
static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[] = { static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[] = {
nullptr, // GGML_OP_NONE nullptr, // GGML_OP_NONE
nullptr, // GGML_OP_DUP nullptr, // GGML_OP_DUP
ggml_qnn_add, // GGML_OP_ADD ggml_qnn_add, // GGML_OP_ADD
nullptr, // GGML_OP_ADD1 nullptr, // GGML_OP_ADD1
nullptr, // GGML_OP_ACC nullptr, // GGML_OP_ACC
nullptr, // GGML_OP_SUB nullptr, // GGML_OP_SUB
nullptr, // GGML_OP_MUL ggml_qnn_mul, // GGML_OP_MUL
nullptr, // GGML_OP_DIV nullptr, // GGML_OP_DIV
nullptr, // GGML_OP_SQR nullptr, // GGML_OP_SQR
nullptr, // GGML_OP_SQRT ggml_qnn_sqrt, // GGML_OP_SQRT
nullptr, // GGML_OP_LOG nullptr, // GGML_OP_LOG
nullptr, // GGML_OP_SUM nullptr, // GGML_OP_SUM
nullptr, // GGML_OP_SUM_ROWS nullptr, // GGML_OP_SUM_ROWS
nullptr, // GGML_OP_MEAN nullptr, // GGML_OP_MEAN
nullptr, // GGML_OP_ARGMAX nullptr, // GGML_OP_ARGMAX
nullptr, // GGML_OP_REPEAT nullptr, // GGML_OP_REPEAT
nullptr, // GGML_OP_REPEAT_BACK nullptr, // GGML_OP_REPEAT_BACK
nullptr, // GGML_OP_CONCAT nullptr, // GGML_OP_CONCAT
nullptr, // GGML_OP_SILU_BACK nullptr, // GGML_OP_SILU_BACK
nullptr, // GGML_OP_NORM nullptr, // GGML_OP_NORM
nullptr, // GGML_OP_RMS_NORM nullptr, // GGML_OP_RMS_NORM
nullptr, // GGML_OP_RMS_NORM_BACK nullptr, // GGML_OP_RMS_NORM_BACK
nullptr, // GGML_OP_GROUP_NORM nullptr, // GGML_OP_GROUP_NORM
ggml_qnn_mul_mat, // GGML_OP_MUL_MAT ggml_qnn_mul_mat, // GGML_OP_MUL_MAT
nullptr, // GGML_OP_MUL_MAT_ID nullptr, // GGML_OP_MUL_MAT_ID