diff --git a/ggml/src/ggml-qnn/graph.hpp b/ggml/src/ggml-qnn/graph.hpp index f2c27aeb3..700114d6f 100644 --- a/ggml/src/ggml-qnn/graph.hpp +++ b/ggml/src/ggml-qnn/graph.hpp @@ -18,7 +18,7 @@ public: explicit ggml_qnn_graph(const std::string &graph_name, QNNBackend device, Qnn_ContextHandle_t qnn_context, QNN_INTERFACE_VER_TYPE qnn_interface, size_t vtcm_size_in_mb) : - _device(device), _qnn_interface(qnn_interface) { + _graph_name(graph_name), _device(device), _qnn_interface(qnn_interface) { QNN_LOG_INFO("graph name %s", graph_name.c_str()); Qnn_ErrorHandle_t error = QNN_SUCCESS; @@ -74,7 +74,8 @@ public: _graph_handle = graph_handle; } - bool add_nodes(const input_tensor_array_t &tensor_inputs, const output_tensor_array_t &tensor_outputs) { + bool add_nodes(const std::string &op_name, const input_tensor_array_t &tensor_inputs, + const output_tensor_array_t &tensor_outputs) { if (!is_valid()) { QNN_LOG_ERROR("Invalid graph\n"); return false; @@ -82,7 +83,7 @@ public: Qnn_Param_t qnn_params[] = {}; Qnn_OpConfig_t op_config = { QNN_OPCONFIG_VERSION_1, - .v1 = { "ggml_op_add", QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, 0, + .v1 = { _graph_name.c_str(), QNN_OP_PACKAGE_NAME_QTI_AISW, op_name.c_str(), 0, qnn_params, _tensor_inputs.size(), _tensor_inputs.data(), _tensor_outputs.size(), _tensor_outputs.data() } }; auto error = _qnn_interface.graphAddNode(_graph_handle, op_config); @@ -122,6 +123,7 @@ public: Qnn_GraphHandle_t get_graph_handler() const { return _graph_handle; } private: + const std::string _graph_name; const QNNBackend _device; const QNN_INTERFACE_VER_TYPE _qnn_interface; Qnn_GraphHandle_t _graph_handle = nullptr; @@ -133,4 +135,8 @@ private: ggml_qnn_graph(ggml_qnn_graph &&) = delete; void operator=(ggml_qnn_graph &&) = delete; }; + +using ggml_qnn_graph_binary = ggml_qnn_graph<2, 1>; +using ggml_qnn_graph_unary = ggml_qnn_graph<1, 1>; + } // namespace qnn