From c1e2283887c3fb6d09b2b3fdedd3847f1060ddfa Mon Sep 17 00:00:00 2001 From: hongruichen Date: Sat, 13 Jul 2024 10:55:36 +0800 Subject: [PATCH] expose op at unit test --- tests/ggml-qnn/ggml-qnn-ut.cpp | 50 ++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/ggml-qnn/ggml-qnn-ut.cpp b/tests/ggml-qnn/ggml-qnn-ut.cpp index 96dfa2bcf..dea336966 100644 --- a/tests/ggml-qnn/ggml-qnn-ut.cpp +++ b/tests/ggml-qnn/ggml-qnn-ut.cpp @@ -327,21 +327,51 @@ static void show_usage() { ); } +typedef ggml_tensor * (*ggml_op_unary_t)( + ggml_context * ctx, + ggml_tensor * a); typedef ggml_tensor * (*ggml_op_binary_t)( ggml_context * ctx, ggml_tensor * a, ggml_tensor * b); +static constexpr const ggml_op_unary_t kUnaryOps[] = { + nullptr, // GGML_OP_NONE + nullptr, // GGML_OP_DUP + nullptr, // GGML_OP_ADD + nullptr, // GGML_OP_ADD1 + nullptr, // GGML_OP_ACC + nullptr, // GGML_OP_SUB + nullptr, // GGML_OP_MUL + nullptr, // GGML_OP_DIV + nullptr, // GGML_OP_SQR + ggml_sqrt, // GGML_OP_SQRT + nullptr, // GGML_OP_LOG + nullptr, // GGML_OP_SUM + nullptr, // GGML_OP_SUM_ROWS + nullptr, // GGML_OP_MEAN + nullptr, // GGML_OP_ARGMAX + nullptr, // GGML_OP_REPEAT + nullptr, // GGML_OP_REPEAT_BACK + nullptr, // GGML_OP_CONCAT + nullptr, // GGML_OP_SILU_BACK + nullptr, // GGML_OP_NORM + nullptr, // GGML_OP_RMS_NORM + nullptr, // GGML_OP_RMS_NORM_BACK + nullptr, // GGML_OP_GROUP_NORM + nullptr, // GGML_OP_MUL_MAT +}; + static constexpr const ggml_op_binary_t kBinaryOps[] = { nullptr, // GGML_OP_NONE nullptr, // GGML_OP_DUP ggml_add, // GGML_OP_ADD nullptr, // GGML_OP_ADD1 nullptr, // GGML_OP_ACC - nullptr, // GGML_OP_SUB + ggml_sub, // GGML_OP_SUB ggml_mul, // GGML_OP_MUL - nullptr, // GGML_OP_DIV + ggml_div, // GGML_OP_DIV nullptr, // GGML_OP_SQR nullptr, // GGML_OP_SQRT nullptr, // GGML_OP_LOG @@ -433,8 +463,11 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) { ggml_set_input(src0); ggml_set_input(src1); + auto unary_op = kUnaryOps[n_ggml_op_type]; auto binary_op = kBinaryOps[n_ggml_op_type]; - if (binary_op) { + if (unary_op) { + dst = unary_op(ctx, src0); + } else if (binary_op) { dst = binary_op(ctx, src0, src1); } else { QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type, @@ -504,10 +537,15 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) { return 0; } +#define DEFINE_OP(op) { #op, op } + static const std::unordered_map kMapStringToGGMLOp = { - {"GGML_OP_ADD", GGML_OP_ADD}, - {"GGML_OP_MUL_MAT", GGML_OP_MUL_MAT}, - {"GGML_OP_MUL", GGML_OP_MUL}, + DEFINE_OP(GGML_OP_ADD), + DEFINE_OP(GGML_OP_SUB), + DEFINE_OP(GGML_OP_MUL), + DEFINE_OP(GGML_OP_DIV), + DEFINE_OP(GGML_OP_SQRT), + DEFINE_OP(GGML_OP_MUL_MAT), }; int main(int argc, char * argv[]) {