fix: try fix tensor type error

commit 27299463ae
parent 28a00e5e6c

3 changed files with 10 additions and 6 deletions
@@ -74,7 +74,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra
     std::array<Qnn_Tensor_t, _InputSize> qnn_input_tensors;
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(inputs[i]);
-        if (!tensor || !tensor->bind_to_graph(*graph)) {
+        if (!tensor || !tensor->bind_to_graph(*graph, true)) {
             return false;
         }

@@ -84,7 +84,7 @@ bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *gra
     std::array<Qnn_Tensor_t, _OutputSize> qnn_output_tensors;
    for (size_t i = 0; i < outputs.size(); ++i) {
         auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(outputs[i]);
-        if (!tensor || !tensor->bind_to_graph(*graph)) {
+        if (!tensor || !tensor->bind_to_graph(*graph, false)) {
             return false;
         }

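For context, a minimal sketch of how the binding helper reads after these two hunks. Only the changed condition lines and the loop context are visible in the diff; the parameter list, the collection of Qnn_Tensor_t handles into qnn_input_tensors/qnn_output_tensors, and the trailing return are assumptions, not the committed code. The point it illustrates: the input loop now passes is_input = true and the output loop passes is_input = false.

// Sketch only, reconstructed from the hunks above; omitted details are assumed.
template <size_t _InputSize, size_t _OutputSize>
bool qnn_bind_tensors_to_graph(qnn::ggml_qnn_graph<_InputSize, _OutputSize> *graph,
                               const std::array<ggml_tensor *, _InputSize> &inputs,
                               const std::array<ggml_tensor *, _OutputSize> &outputs) {
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(inputs[i]);
        if (!tensor || !tensor->bind_to_graph(*graph, true)) {   // inputs: is_input = true
            return false;
        }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(outputs[i]);
        if (!tensor || !tensor->bind_to_graph(*graph, false)) {  // outputs: is_input = false
            return false;
        }
    }
    return true;  // assumed: actual function also returns the bound QNN tensor handles
}
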
@@ -43,7 +43,8 @@ public:
         _dimensions[2] = (uint32_t)tensor->ne[2];
         _dimensions[3] = (uint32_t)tensor->ne[3];
         QNN_TENSOR_SET_DIMENSIONS(_qnn_tensor, _dimensions);
-        QNN_TENSOR_SET_TYPE(_qnn_tensor, device_tensortype_from_ggml_tensor(tensor));
+        auto qnn_tensor_type = device_tensortype_from_ggml_tensor(tensor);
+        QNN_TENSOR_SET_TYPE(_qnn_tensor, qnn_tensor_type);
         QNN_TENSOR_SET_DATA_FORMAT(_qnn_tensor, QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER);
         QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type));
         // TODO: set the quantizeParams base on the tensor type
@@ -54,11 +55,11 @@ public:
         QNN_TENSOR_SET_CLIENT_BUF(_qnn_tensor, client_buf);

         tensor->extra = this;
-        QNN_LOG_DEBUG("create tensor %s with device %d", _tensor_name.c_str(), device);
+        QNN_LOG_DEBUG("create tensor %s, device: %d, qnn_type: %d", _tensor_name.c_str(), device, (int)qnn_tensor_type);
     }

     template <size_t _InputSize, size_t _OutputSize>
-    bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph) {
+    bool bind_to_graph(ggml_qnn_graph<_InputSize, _OutputSize> &graph, bool is_input) {
         if (!is_valid()) {
             QNN_LOG_WARN("tensor %s not valid", _tensor_name.c_str());
             return false;
@@ -75,6 +76,9 @@ public:
             }
         }

+        Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
+        QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
+        QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
         Qnn_Tensor_t tensor = _qnn_tensor;
         if (!graph.create_graph_tensor(tensor)) {
             QNN_LOG_WARN("create graph tensor failed, tensor %s", _tensor_name.c_str());

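Net effect of the two file-2 hunks: the constructor records whatever type device_tensortype_from_ggml_tensor derives from the ggml flags, and bind_to_graph then overrides it according to how the tensor is actually used in that graph. Under the usual QNN convention, APP_WRITE marks a tensor the application writes (a graph input) and APP_READ one the application reads back (a graph output). A hypothetical call site; graph and ggml_src are placeholder names, not from this commit:

// Hypothetical usage sketch of the new bind_to_graph(graph, is_input) overload.
auto tensor = qnn::ggml_qnn_tensor::from_ggml_tensor(ggml_src);
if (tensor && tensor->bind_to_graph(graph, /*is_input=*/true)) {
    // the wrapped QNN tensor is now registered with type QNN_TENSOR_TYPE_APP_WRITE
}
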
@@ -30,7 +30,7 @@ Qnn_DataType_t device_datatype_from_ggml_datatype(ggml_type ggml_type) {
     }

 Qnn_TensorType_t device_tensortype_from_ggml_tensor(ggml_tensor *ggml_tensor) {
-    Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;
+    Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_NATIVE;

     if (ggml_tensor->flags & GGML_TENSOR_FLAG_INPUT) {
         qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;
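The hunk above is truncated after the INPUT branch. A sketch of how the full helper likely reads after this commit; the GGML_TENSOR_FLAG_OUTPUT branch and the final return are assumptions inferred from the new default and the override in bind_to_graph, not visible in the diff:

// Sketch only: assumed full body of device_tensortype_from_ggml_tensor after this commit.
Qnn_TensorType_t device_tensortype_from_ggml_tensor(ggml_tensor *ggml_tensor) {
    Qnn_TensorType_t qnn_tensor_type = QNN_TENSOR_TYPE_NATIVE;    // default: internal graph tensor

    if (ggml_tensor->flags & GGML_TENSOR_FLAG_INPUT) {
        qnn_tensor_type = QNN_TENSOR_TYPE_APP_WRITE;              // graph input, written by the app
    } else if (ggml_tensor->flags & GGML_TENSOR_FLAG_OUTPUT) {    // assumed symmetric branch
        qnn_tensor_type = QNN_TENSOR_TYPE_APP_READ;               // graph output, read by the app
    }

    return qnn_tensor_type;
}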