From 27299463ae74b8d72ce84780a61f25ad77634f0f Mon Sep 17 00:00:00 2001 From: hongruichen Date: Sat, 20 Jul 2024 14:23:44 +0800 Subject: [PATCH] fix: try fix tensor type error --- ggml/src/ggml-qnn/backend-ops.cpp | 4 ++-- ggml/src/ggml-qnn/tensor.hpp | 10 +++++++--- ggml/src/ggml-qnn/utils.cpp | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-qnn/backend-ops.cpp b/ggml/src/ggml-qnn/backend-ops.cpp index 6367e7c70..1e7920598 100644 --- a/ggml/src/ggml-qnn/backend-ops.cpp +++ b/ggml/src/ggml-qnn/backend-ops.cpp @@ -74,7 +74,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra std::array qnn_input_tensors; for (size_t i = 0; i < inputs.size(); ++i) { auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(inputs[i]); - if (!tensor || !tensor->bind_to_graph(*graph)) { + if (!tensor || !tensor->bind_to_graph(*graph, true)) { return false; } @@ -84,7 +84,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra std::array qnn_output_tensors; for (size_t i = 0; i < outputs.size(); ++i) { auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(outputs[i]); - if (!tensor || !tensor->bind_to_graph(*graph)) { + if (!tensor || !tensor->bind_to_graph(*graph, false)) { return false; } diff --git a/ggml/src/ggml-qnn/tensor.hpp b/ggml/src/ggml-qnn/tensor.hpp index e5dc436ad..9137b5d86 100644 --- a/ggml/src/ggml-qnn/tensor.hpp +++ b/ggml/src/ggml-qnn/tensor.hpp @@ -43,7 +43,8 @@ public: _dimensions[2] = (uint32_t)tensor->ne[2]; _dimensions[3] = (uint32_t)tensor->ne[3]; QNN_TENSOR_SET_DIMENSIONS(_qnn_tensor, _dimensions); - QNN_TENSOR_SET_TYPE(_qnn_tensor, device_tensortype_from_ggml_tensor(tensor)); + auto qnn_tensor_type = device_tensortype_from_ggml_tensor(tensor); + QNN_TENSOR_SET_TYPE(_qnn_tensor, qnn_tensor_type); QNN_TENSOR_SET_DATA_FORMAT(_qnn_tensor, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER); QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type)); // TODO: set the quantizeParams base on the tensor type @@ -54,11 +55,11 @@ public: QNN_TENSOR_SET_CLIENT_BUF(_qnn_tensor, client_buf); tensor->extra = this; - QNN_LOG_DEBUG("create tensor %s with device %d", _tensor_name.c_str(), device); + QNN_LOG_DEBUG("create tensor %s, device: %d, qnn_type: %d", _tensor_name.c_str(), device, (int)qnn_tensor_type); } template - bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph) { + bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph, bool is_input) { if (!is_valid()) { QNN_LOG_WARN("tensor %s not valid", _tensor_name.c_str()); return false; @@ -75,6 +76,9 @@ public: } } + Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ; + QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type); + QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type); Qnn_Tensor_t tensor = _qnn_tensor; if (!graph.create_graph_tensor(tensor)) { QNN_LOG_WARN("create graph tensor failed, tensor %s", _tensor_name.c_str()); diff --git a/ggml/src/ggml-qnn/utils.cpp b/ggml/src/ggml-qnn/utils.cpp index 70a898b95..820b72b89 100644 --- a/ggml/src/ggml-qnn/utils.cpp +++ b/ggml/src/ggml-qnn/utils.cpp @@ -30,7 +30,7 @@ Qnn_DataType_t device_datatype_from_ggml_datatype(ggml_type ggml_type) { } Qnn_TensorType_t device_tensortype_from_ggml_tensor(ggml_tensor *ggml_tensor) { - Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE; + Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_NATIVE; if (ggml_tensor->flags & GGML_TENSOR_FLAG_INPUT) { qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;