add op param to add_nodes
parent 4b2ee61f62
commit a688ed324b
1 changed file with 9 additions and 3 deletions
@@ -18,7 +18,7 @@ public:
     explicit ggml_qnn_graph(const std::string &graph_name, QNNBackend device, Qnn_ContextHandle_t qnn_context,
                             QNN_INTERFACE_VER_TYPE qnn_interface, size_t vtcm_size_in_mb) :
-        _device(device), _qnn_interface(qnn_interface) {
+        _graph_name(graph_name), _device(device), _qnn_interface(qnn_interface) {
         QNN_LOG_INFO("graph name %s", graph_name.c_str());

         Qnn_ErrorHandle_t error = QNN_SUCCESS;
@@ -74,7 +74,8 @@ public:
         _graph_handle = graph_handle;
     }

-    bool add_nodes(const input_tensor_array_t &tensor_inputs, const output_tensor_array_t &tensor_outputs) {
+    bool add_nodes(const std::string &op_name, const input_tensor_array_t &tensor_inputs,
+                   const output_tensor_array_t &tensor_outputs) {
         if (!is_valid()) {
             QNN_LOG_ERROR("Invalid graph\n");
             return false;
@@ -82,7 +83,7 @@ public:

         Qnn_Param_t qnn_params[] = {};
         Qnn_OpConfig_t op_config = { QNN_OPCONFIG_VERSION_1,
-                                     .v1 = { "ggml_op_add", QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, 0,
+                                     .v1 = { _graph_name.c_str(), QNN_OP_PACKAGE_NAME_QTI_AISW, op_name.c_str(), 0,
                                              qnn_params, _tensor_inputs.size(), _tensor_inputs.data(),
                                              _tensor_outputs.size(), _tensor_outputs.data() } };
         auto error = _qnn_interface.graphAddNode(_graph_handle, op_config);
@@ -122,6 +123,7 @@ public:
     Qnn_GraphHandle_t get_graph_handler() const { return _graph_handle; }

 private:
+    const std::string _graph_name;
     const QNNBackend _device;
     const QNN_INTERFACE_VER_TYPE _qnn_interface;
     Qnn_GraphHandle_t _graph_handle = nullptr;
@@ -133,4 +135,8 @@ private:
     ggml_qnn_graph(ggml_qnn_graph &&) = delete;
     void operator=(ggml_qnn_graph &&) = delete;
 };
+
+using ggml_qnn_graph_binary = ggml_qnn_graph<2, 1>;
+using ggml_qnn_graph_unary = ggml_qnn_graph<1, 1>;
+
 } // namespace qnn
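For orientation, a minimal caller-side sketch of the new add_nodes signature. This is not part of the commit: device, context, interface, vtcm_size_in_mb, src0, src1 and dst are hypothetical handles prepared elsewhere in the backend, and the element type of input_tensor_array_t/output_tensor_array_t is not visible in this hunk. QNN_OP_ELEMENT_WISE_ADD is the QNN SDK op-type string that the removed line hard-coded.

    // Sketch only: tensor handles and setup objects are placeholders.
    qnn::ggml_qnn_graph_binary graph("ggml_op_add", device, context, interface, vtcm_size_in_mb);
    if (graph.is_valid() &&
        graph.add_nodes(QNN_OP_ELEMENT_WISE_ADD, { src0, src1 }, { dst })) {
        // Node added with the caller-supplied op name; the graph name now serves
        // as the Qnn_OpConfig_t node name instead of the literal "ggml_op_add".
    }

With the op name passed as a parameter, the same wrapper can build a graph for any single-op case covered by the ggml_qnn_graph_unary/ggml_qnn_graph_binary aliases, rather than being tied to the element-wise add config.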