add op param to add_nodes

This commit is contained in:
hongruichen 2024-07-05 13:07:48 +08:00
parent 4b2ee61f62
commit a688ed324b

View file

@ -18,7 +18,7 @@ public:
explicit ggml_qnn_graph(const std::string &graph_name, QNNBackend device, Qnn_ContextHandle_t qnn_context, explicit ggml_qnn_graph(const std::string &graph_name, QNNBackend device, Qnn_ContextHandle_t qnn_context,
QNN_INTERFACE_VER_TYPE qnn_interface, size_t vtcm_size_in_mb) : QNN_INTERFACE_VER_TYPE qnn_interface, size_t vtcm_size_in_mb) :
_device(device), _qnn_interface(qnn_interface) { _graph_name(graph_name), _device(device), _qnn_interface(qnn_interface) {
QNN_LOG_INFO("graph name %s", graph_name.c_str()); QNN_LOG_INFO("graph name %s", graph_name.c_str());
Qnn_ErrorHandle_t error = QNN_SUCCESS; Qnn_ErrorHandle_t error = QNN_SUCCESS;
@ -74,7 +74,8 @@ public:
_graph_handle = graph_handle; _graph_handle = graph_handle;
} }
bool add_nodes(const input_tensor_array_t &tensor_inputs, const output_tensor_array_t &tensor_outputs) { bool add_nodes(const std::string &op_name, const input_tensor_array_t &tensor_inputs,
const output_tensor_array_t &tensor_outputs) {
if (!is_valid()) { if (!is_valid()) {
QNN_LOG_ERROR("Invalid graph\n"); QNN_LOG_ERROR("Invalid graph\n");
return false; return false;
@ -82,7 +83,7 @@ public:
Qnn_Param_t qnn_params[] = {}; Qnn_Param_t qnn_params[] = {};
Qnn_OpConfig_t op_config = { QNN_OPCONFIG_VERSION_1, Qnn_OpConfig_t op_config = { QNN_OPCONFIG_VERSION_1,
.v1 = { "ggml_op_add", QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, 0, .v1 = { _graph_name.c_str(), QNN_OP_PACKAGE_NAME_QTI_AISW, op_name.c_str(), 0,
qnn_params, _tensor_inputs.size(), _tensor_inputs.data(), qnn_params, _tensor_inputs.size(), _tensor_inputs.data(),
_tensor_outputs.size(), _tensor_outputs.data() } }; _tensor_outputs.size(), _tensor_outputs.data() } };
auto error = _qnn_interface.graphAddNode(_graph_handle, op_config); auto error = _qnn_interface.graphAddNode(_graph_handle, op_config);
@ -122,6 +123,7 @@ public:
Qnn_GraphHandle_t get_graph_handler() const { return _graph_handle; } Qnn_GraphHandle_t get_graph_handler() const { return _graph_handle; }
private: private:
const std::string _graph_name;
const QNNBackend _device; const QNNBackend _device;
const QNN_INTERFACE_VER_TYPE _qnn_interface; const QNN_INTERFACE_VER_TYPE _qnn_interface;
Qnn_GraphHandle_t _graph_handle = nullptr; Qnn_GraphHandle_t _graph_handle = nullptr;
@ -133,4 +135,8 @@ private:
ggml_qnn_graph(ggml_qnn_graph &&) = delete; ggml_qnn_graph(ggml_qnn_graph &&) = delete;
void operator=(ggml_qnn_graph &&) = delete; void operator=(ggml_qnn_graph &&) = delete;
}; };
using ggml_qnn_graph_binary = ggml_qnn_graph<2, 1>;
using ggml_qnn_graph_unary = ggml_qnn_graph<1, 1>;
} // namespace qnn } // namespace qnn