fix: try to fix tensor type error

This commit is contained in:
hongruichen 2024-07-20 14:23:44 +08:00
parent 28a00e5e6c
commit 27299463ae
3 changed files with 10 additions and 6 deletions

View file

@ -74,7 +74,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra
std::array<Qnn_Tensor_t, _InputSize> qnn_input_tensors;
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(inputs[i]);
if (!tensor || !tensor->bind_to_graph(*graph)) {
if (!tensor || !tensor->bind_to_graph(*graph, true)) {
return false;
}
@ -84,7 +84,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra
std::array<Qnn_Tensor_t, _OutputSize> qnn_output_tensors;
for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(outputs[i]);
if (!tensor || !tensor->bind_to_graph(*graph)) {
if (!tensor || !tensor->bind_to_graph(*graph, false)) {
return false;
}

View file

@ -43,7 +43,8 @@ public:
_dimensions[2] = (uint32_t)tensor->ne[2];
_dimensions[3] = (uint32_t)tensor->ne[3];
QNN_TENSOR_SET_DIMENSIONS(_qnn_tensor, _dimensions);
QNN_TENSOR_SET_TYPE(_qnn_tensor, device_tensortype_from_ggml_tensor(tensor));
auto qnn_tensor_type = device_tensortype_from_ggml_tensor(tensor);
QNN_TENSOR_SET_TYPE(_qnn_tensor, qnn_tensor_type);
QNN_TENSOR_SET_DATA_FORMAT(_qnn_tensor, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER);
QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type));
// TODO: set the quantizeParams based on the tensor type
@ -54,11 +55,11 @@ public:
QNN_TENSOR_SET_CLIENT_BUF(_qnn_tensor, client_buf);
tensor->extra = this;
QNN_LOG_DEBUG("create tensor %s with device %d", _tensor_name.c_str(), device);
QNN_LOG_DEBUG("create tensor %s, device: %d, qnn_type: %d", _tensor_name.c_str(), device, (int)qnn_tensor_type);
}
template <size_t _InputSize, size_t _OutputSize>
bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph) {
bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph, bool is_input) {
if (!is_valid()) {
QNN_LOG_WARN("tensor %s not valid", _tensor_name.c_str());
return false;
@ -75,6 +76,9 @@ public:
}
}
Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
Qnn_Tensor_t tensor = _qnn_tensor;
if (!graph.create_graph_tensor(tensor)) {
QNN_LOG_WARN("create graph tensor failed, tensor %s", _tensor_name.c_str());

View file

@ -30,7 +30,7 @@ Qnn_DataType_t device_datatype_from_ggml_datatype(ggml_type ggml_type) {
}
Qnn_TensorType_t device_tensortype_from_ggml_tensor(ggml_tensor *ggml_tensor) {
Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;
Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_NATIVE;
if (ggml_tensor->flags & GGML_TENSOR_FLAG_INPUT) {
qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;