From 0eb595cc6e7691d8fbf30b7b04cdf8dd7eb108e3 Mon Sep 17 00:00:00 2001 From: hongruichen Date: Fri, 12 Jul 2024 19:52:35 +0800 Subject: [PATCH] use table to simpilify the op mapping --- tests/ggml-qnn/ggml-qnn-ut.cpp | 69 +++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/tests/ggml-qnn/ggml-qnn-ut.cpp b/tests/ggml-qnn/ggml-qnn-ut.cpp index ff01e62f9..0c3fbf71e 100644 --- a/tests/ggml-qnn/ggml-qnn-ut.cpp +++ b/tests/ggml-qnn/ggml-qnn-ut.cpp @@ -327,6 +327,41 @@ static void show_usage() { ); } + +typedef ggml_tensor * (*ggml_op_binary_t)( + ggml_context * ctx, + ggml_tensor * a, + ggml_tensor * b); + +static constexpr const ggml_op_binary_t kBinaryOps[] = { + nullptr, // GGML_OP_NONE + nullptr, // GGML_OP_DUP + ggml_add, // GGML_OP_ADD + nullptr, // GGML_OP_ADD1 + nullptr, // GGML_OP_ACC + nullptr, // GGML_OP_SUB + ggml_mul, // GGML_OP_MUL + nullptr, // GGML_OP_DIV + nullptr, // GGML_OP_SQR + nullptr, // GGML_OP_SQRT + nullptr, // GGML_OP_LOG + nullptr, // GGML_OP_SUM + nullptr, // GGML_OP_SUM_ROWS + nullptr, // GGML_OP_MEAN + nullptr, // GGML_OP_ARGMAX + nullptr, // GGML_OP_REPEAT + nullptr, // GGML_OP_REPEAT_BACK + nullptr, // GGML_OP_CONCAT + nullptr, // GGML_OP_SILU_BACK + nullptr, // GGML_OP_NORM + nullptr, // GGML_OP_RMS_NORM + nullptr, // GGML_OP_RMS_NORM_BACK + nullptr, // GGML_OP_GROUP_NORM + ggml_mul_mat, // GGML_OP_MUL_MAT +}; + +static_assert(kBinaryOps[GGML_OP_MUL_MAT] == ggml_mul_mat, "ggml_mul_mat at wrong index, check kBinaryOps"); + static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) { int64_t n_begin_time = 0LL; int64_t n_end_time = 0LL; @@ -398,19 +433,15 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) { ggml_set_input(src0); ggml_set_input(src1); - switch (n_ggml_op_type) { - case GGML_OP_ADD: - dst = ggml_add(ctx, src0, src1); - break; - case GGML_OP_MUL_MAT: - dst = ggml_mul_mat(ctx, src0, src1); - break; - default: - QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type, - ggml_op_name((enum ggml_op) n_ggml_op_type)); - ggml_free(ctx); - ggml_backend_free(backend); - return 3; + auto binary_op = kBinaryOps[n_ggml_op_type]; + if (binary_op) { + dst = binary_op(ctx, src0, src1); + } else { + QNN_LOG_WARN("ggml op %d(%s) not supported", n_ggml_op_type, + ggml_op_name((enum ggml_op) n_ggml_op_type)); + ggml_free(ctx); + ggml_backend_free(backend); + return 3; } ggml_set_output(dst); @@ -473,6 +504,11 @@ static int qnn_op_ut(int num_threads, int n_backend_type, int n_ggml_op_type) { return 0; } +static const std::unordered_map kMapStringToGGMLOp = { + {"GGML_OP_ADD", GGML_OP_ADD}, + {"GGML_OP_MUL_MAT", GGML_OP_MUL_MAT}, +}; + int main(int argc, char * argv[]) { int num_threads = 4; int n_backend_type = QNN_BACKEND_CPU; @@ -481,10 +517,9 @@ int main(int argc, char * argv[]) { for (int i = 1; i < argc; i++) { if (0 == strcmp(argv[i], "-t")) { if (i + 1 < argc) { - if (0 == memcmp(argv[i + 1], "GGML_OP_ADD", 11)) { - n_ggml_op_type = GGML_OP_ADD; - } else if (0 == memcmp(argv[i + 1], "GGML_OP_MUL_MAT", 15)) { - n_ggml_op_type = GGML_OP_MUL_MAT; + auto it = kMapStringToGGMLOp.find(argv[i + 1]); + if (it != kMapStringToGGMLOp.end()) { + n_ggml_op_type = it->second; } else { show_usage(); return 1;