diff --git a/ggml/src/ggml-qnn/backend-ops.cpp b/ggml/src/ggml-qnn/backend-ops.cpp index 2627e23fd..5871a7b6e 100644 --- a/ggml/src/ggml-qnn/backend-ops.cpp +++ b/ggml/src/ggml-qnn/backend-ops.cpp @@ -120,7 +120,7 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c graph_ptr = it->second.get(); } else { std::string graph_name = graph_key + "_" + std::to_string(ctx->threads); - for (auto &input: inputs) { + for (auto &input : inputs) { graph_name += "_"; graph_name += input->name; } @@ -143,6 +143,116 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c return graph_ptr; } +constexpr const char *kGgmlOpToQnnOp[] = { + nullptr, // GGML_OP_NONE + nullptr, // GGML_OP_DUP + QNN_OP_ELEMENT_WISE_ADD, // GGML_OP_ADD + nullptr, // GGML_OP_ADD1 + nullptr, // GGML_OP_ACC + nullptr, // GGML_OP_SUB + nullptr, // GGML_OP_MUL + nullptr, // GGML_OP_DIV + nullptr, // GGML_OP_SQR + nullptr, // GGML_OP_SQRT + nullptr, // GGML_OP_LOG + nullptr, // GGML_OP_SUM + nullptr, // GGML_OP_SUM_ROWS + nullptr, // GGML_OP_MEAN + nullptr, // GGML_OP_ARGMAX + nullptr, // GGML_OP_REPEAT + nullptr, // GGML_OP_REPEAT_BACK + nullptr, // GGML_OP_CONCAT + nullptr, // GGML_OP_SILU_BACK + nullptr, // GGML_OP_NORM + nullptr, // GGML_OP_RMS_NORM + nullptr, // GGML_OP_RMS_NORM_BACK + nullptr, // GGML_OP_GROUP_NORM + + QNN_OP_MAT_MUL, // GGML_OP_MUL_MAT + nullptr, // GGML_OP_MUL_MAT_ID + nullptr, // GGML_OP_OUT_PROD + + nullptr, // GGML_OP_SCALE + nullptr, // GGML_OP_SET + nullptr, // GGML_OP_CPY + nullptr, // GGML_OP_CONT + nullptr, // GGML_OP_RESHAPE + nullptr, // GGML_OP_VIEW + nullptr, // GGML_OP_PERMUTE + nullptr, // GGML_OP_TRANSPOSE + nullptr, // GGML_OP_GET_ROWS + nullptr, // GGML_OP_GET_ROWS_BACK + nullptr, // GGML_OP_DIAG + nullptr, // GGML_OP_DIAG_MASK_INF + nullptr, // GGML_OP_DIAG_MASK_ZERO + nullptr, // GGML_OP_SOFT_MAX + nullptr, // GGML_OP_SOFT_MAX_BACK + nullptr, // GGML_OP_ROPE + nullptr, // GGML_OP_ROPE_BACK + nullptr, // GGML_OP_CLAMP + nullptr, // GGML_OP_CONV_TRANSPOSE_1D + nullptr, // GGML_OP_IM2COL + nullptr, // GGML_OP_CONV_TRANSPOSE_2D + nullptr, // GGML_OP_POOL_1D + nullptr, // GGML_OP_POOL_2D + nullptr, // GGML_OP_UPSCALE + nullptr, // GGML_OP_PAD + nullptr, // GGML_OP_ARANGE + nullptr, // GGML_OP_TIMESTEP_EMBEDDING + nullptr, // GGML_OP_ARGSORT + nullptr, // GGML_OP_LEAKY_RELU + + nullptr, // GGML_OP_FLASH_ATTN_EXT + nullptr, // GGML_OP_FLASH_ATTN_BACK + nullptr, // GGML_OP_SSM_CONV + nullptr, // GGML_OP_SSM_SCAN + nullptr, // GGML_OP_WIN_PART + nullptr, // GGML_OP_WIN_UNPART + nullptr, // GGML_OP_GET_REL_POS + nullptr, // GGML_OP_ADD_REL_POS + + nullptr, // GGML_OP_UNARY + + nullptr, // GGML_OP_MAP_UNARY + nullptr, // GGML_OP_MAP_BINARY + + nullptr, // GGML_OP_MAP_CUSTOM1_F32 + nullptr, // GGML_OP_MAP_CUSTOM2_F32 + nullptr, // GGML_OP_MAP_CUSTOM3_F32 + + nullptr, // GGML_OP_MAP_CUSTOM1 + nullptr, // GGML_OP_MAP_CUSTOM2 + nullptr, // GGML_OP_MAP_CUSTOM3 + + nullptr, // GGML_OP_CROSS_ENTROPY_LOSS + nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK +}; + +static_assert(sizeof(kGgmlOpToQnnOp) / sizeof(kGgmlOpToQnnOp[0]) == GGML_OP_COUNT, + "GGML_OP_COUNT does not match the size of the ops table"); + +template +void qnn_binary_op_impl(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst) { + static_assert(kGgmlOpToQnnOp[_GgmlOp] != nullptr, "GGML_OP does not have a corresponding QNN_OP"); + + qnn::qnn_perf perf(ggml_op_name(_GgmlOp)); + perf.start(); + + bool succeed = false; + qnn::ggml_qnn_graph_binary *graph_ptr = + get_qnn_graph_from_cache<2, 1>(ctx, _GgmlOp, kGgmlOpToQnnOp[_GgmlOp], { src0, src1 }, { dst }); + if (graph_ptr) { + succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst }); + } + + if (!succeed) { + print_ggml_tensor(src0); + print_ggml_tensor(src1); + print_ggml_tensor(dst); + } +} + } // namespace #ifndef NDEBUG @@ -160,22 +270,7 @@ qnn::ggml_qnn_graph_binary *get_qnn_graph_from_cache(ggml_backend_qnn_context *c static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) { CHECK_PARAMS(ctx, src0, src1, dst); - - qnn::qnn_perf perf("ggml_op_qnn_add"); - perf.start(); - - bool succeed = false; - qnn::ggml_qnn_graph_binary *graph_ptr = - get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_ADD, QNN_OP_ELEMENT_WISE_ADD, { src0, src1 }, { dst }); - if (graph_ptr) { - succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst }); - } - - if (!succeed) { - print_ggml_tensor(src0); - print_ggml_tensor(src1); - print_ggml_tensor(dst); - } + qnn_binary_op_impl(ctx, src0, src1, dst); } /* @@ -192,22 +287,7 @@ static void ggml_qnn_add(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, static void ggml_qnn_mul_mat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) { CHECK_PARAMS(ctx, src0, src1, dst); - - qnn::qnn_perf perf("ggml_op_qnn_mul_mat"); - perf.start(); - - bool succeed = false; - qnn::ggml_qnn_graph_binary *graph_ptr = - get_qnn_graph_from_cache<2, 1>(ctx, GGML_OP_MUL_MAT, QNN_OP_MAT_MUL, { src0, src1 }, { dst }); - if (graph_ptr) { - succeed = execute_graph<2, 1>(graph_ptr, { src0, src1 }, { dst }); - } - - if (!succeed) { - print_ggml_tensor(src0); - print_ggml_tensor(src1); - print_ggml_tensor(dst); - } + qnn_binary_op_impl(ctx, src0, src1, dst); } static void ggml_qnn_repeat(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, const ggml_tensor *src1, @@ -320,7 +400,7 @@ static void ggml_qnn_nop(ggml_backend_qnn_context *ctx, const ggml_tensor *src0, } qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() { - static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[GGML_OP_COUNT] = { + static constexpr const qnn::ggml_qnn_op_t kQnnOpsTable[] = { nullptr, // GGML_OP_NONE nullptr, // GGML_OP_DUP ggml_qnn_add, // GGML_OP_ADD @@ -405,5 +485,7 @@ qnn::ggml_qnn_op_array_t qnn::ggml_qnn_op_array() { nullptr, // GGML_OP_CROSS_ENTROPY_LOSS_BACK }; + static_assert(sizeof(kQnnOpsTable) / sizeof(kQnnOpsTable[0]) == GGML_OP_COUNT, + "GGML_OP_COUNT does not match the size of the ops table"); return kQnnOpsTable; }