fix: try fix the tensor rank of mul mat
This commit is contained in:
parent
6cc7432b37
commit
47f6e02eda
4 changed files with 26 additions and 13 deletions
|
@ -133,6 +133,7 @@ qnn::ggml_op_constructor_t generate_common_op_constructor(const std::string &op_
|
||||||
scalar.dataType = QNN_DATATYPE_BOOL_8;
|
scalar.dataType = QNN_DATATYPE_BOOL_8;
|
||||||
scalar.bool8Value = true;
|
scalar.bool8Value = true;
|
||||||
config->add_scalar_param(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar);
|
config->add_scalar_param(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar);
|
||||||
|
QNN_LOG_DEBUG("add scalar param %s\n", QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0);
|
||||||
return config;
|
return config;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,6 +91,14 @@ public:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get the max tensor rank
|
||||||
|
for (auto tensor : tensor_inputs) {
|
||||||
|
_tensor_rank = std::max(_tensor_rank, ggml_n_dims(tensor));
|
||||||
|
}
|
||||||
|
for (auto tensor : tensor_outputs) {
|
||||||
|
_tensor_rank = std::max(_tensor_rank, ggml_n_dims(tensor));
|
||||||
|
}
|
||||||
|
|
||||||
QNN_LOG_DEBUG("graph name %s, build_graph start", _graph_name.c_str());
|
QNN_LOG_DEBUG("graph name %s, build_graph start", _graph_name.c_str());
|
||||||
_tensor_inputs.resize(tensor_inputs.size());
|
_tensor_inputs.resize(tensor_inputs.size());
|
||||||
for (size_t i = 0; i < tensor_inputs.size(); i++) {
|
for (size_t i = 0; i < tensor_inputs.size(); i++) {
|
||||||
|
@ -99,7 +107,7 @@ public:
|
||||||
auto qnn_tensor =
|
auto qnn_tensor =
|
||||||
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
|
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
|
||||||
auto *ggml_tensor = tensor_inputs[i];
|
auto *ggml_tensor = tensor_inputs[i];
|
||||||
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, true)) {
|
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, true, _tensor_rank)) {
|
||||||
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -114,7 +122,7 @@ public:
|
||||||
auto qnn_tensor =
|
auto qnn_tensor =
|
||||||
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
|
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
|
||||||
auto *ggml_tensor = tensor_outputs[i];
|
auto *ggml_tensor = tensor_outputs[i];
|
||||||
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, false)) {
|
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, false, _tensor_rank)) {
|
||||||
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -156,7 +164,7 @@ public:
|
||||||
GGML_ASSERT(tensor_outputs.size() == _tensor_outputs.size());
|
GGML_ASSERT(tensor_outputs.size() == _tensor_outputs.size());
|
||||||
for (size_t i = 0; i < tensor_inputs.size(); i++) {
|
for (size_t i = 0; i < tensor_inputs.size(); i++) {
|
||||||
auto *ggml_tensor = tensor_inputs[i];
|
auto *ggml_tensor = tensor_inputs[i];
|
||||||
if (!_tensor_inputs[i]->bind_ggml_tensor(ggml_tensor, true)) {
|
if (!_tensor_inputs[i]->bind_ggml_tensor(ggml_tensor, true, _tensor_rank)) {
|
||||||
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -164,7 +172,7 @@ public:
|
||||||
|
|
||||||
for (size_t i = 0; i < tensor_outputs.size(); i++) {
|
for (size_t i = 0; i < tensor_outputs.size(); i++) {
|
||||||
auto *ggml_tensor = tensor_outputs[i];
|
auto *ggml_tensor = tensor_outputs[i];
|
||||||
if (!_tensor_outputs[i]->bind_ggml_tensor(ggml_tensor, false)) {
|
if (!_tensor_outputs[i]->bind_ggml_tensor(ggml_tensor, false, _tensor_rank)) {
|
||||||
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -216,6 +224,7 @@ private:
|
||||||
std::vector<std::shared_ptr<ggml_qnn_tensor>> _tensor_outputs;
|
std::vector<std::shared_ptr<ggml_qnn_tensor>> _tensor_outputs;
|
||||||
std::unique_ptr<ggml_qnn_op_config> _op_config;
|
std::unique_ptr<ggml_qnn_op_config> _op_config;
|
||||||
std::vector<Qnn_Param_t> _param_types;
|
std::vector<Qnn_Param_t> _param_types;
|
||||||
|
int _tensor_rank = 0;
|
||||||
|
|
||||||
DISABLE_COPY(ggml_qnn_graph);
|
DISABLE_COPY(ggml_qnn_graph);
|
||||||
DISABLE_MOVE(ggml_qnn_graph);
|
DISABLE_MOVE(ggml_qnn_graph);
|
||||||
|
|
|
@ -36,7 +36,7 @@ public:
|
||||||
param.paramType = QNN_PARAMTYPE_SCALAR;
|
param.paramType = QNN_PARAMTYPE_SCALAR;
|
||||||
param.name = _param_names.back().c_str();
|
param.name = _param_names.back().c_str();
|
||||||
param.scalarParam = scalar;
|
param.scalarParam = scalar;
|
||||||
_param_types.push_back(param);
|
_parameters.push_back(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Qnn_Tensor_t> &get_qnn_input_tensors() { return _qnn_tensor_inputs; }
|
std::vector<Qnn_Tensor_t> &get_qnn_input_tensors() { return _qnn_tensor_inputs; }
|
||||||
|
@ -49,8 +49,8 @@ public:
|
||||||
op_config.name = _name.c_str();
|
op_config.name = _name.c_str();
|
||||||
op_config.packageName = _package_name.c_str();
|
op_config.packageName = _package_name.c_str();
|
||||||
op_config.typeName = _op_type.c_str();
|
op_config.typeName = _op_type.c_str();
|
||||||
op_config.numOfParams = (uint32_t)_param_types.size();
|
op_config.numOfParams = (uint32_t)_parameters.size();
|
||||||
op_config.params = _param_types.data();
|
op_config.params = _parameters.data();
|
||||||
op_config.numOfInputs = (uint32_t)_qnn_tensor_inputs.size();
|
op_config.numOfInputs = (uint32_t)_qnn_tensor_inputs.size();
|
||||||
op_config.inputTensors = _qnn_tensor_inputs.data();
|
op_config.inputTensors = _qnn_tensor_inputs.data();
|
||||||
op_config.numOfOutputs = (uint32_t)_qnn_tensor_outputs.size();
|
op_config.numOfOutputs = (uint32_t)_qnn_tensor_outputs.size();
|
||||||
|
@ -64,7 +64,7 @@ private:
|
||||||
std::string _op_type;
|
std::string _op_type;
|
||||||
std::vector<Qnn_Tensor_t> _qnn_tensor_inputs;
|
std::vector<Qnn_Tensor_t> _qnn_tensor_inputs;
|
||||||
std::vector<Qnn_Tensor_t> _qnn_tensor_outputs;
|
std::vector<Qnn_Tensor_t> _qnn_tensor_outputs;
|
||||||
std::vector<Qnn_Param_t> _param_types;
|
std::vector<Qnn_Param_t> _parameters;
|
||||||
std::vector<std::string> _param_names;
|
std::vector<std::string> _param_names;
|
||||||
|
|
||||||
DISABLE_COPY(ggml_qnn_op_config);
|
DISABLE_COPY(ggml_qnn_op_config);
|
||||||
|
|
|
@ -29,7 +29,7 @@ public:
|
||||||
|
|
||||||
~ggml_qnn_tensor() { _qnn_rpc_buffer.reset(); }
|
~ggml_qnn_tensor() { _qnn_rpc_buffer.reset(); }
|
||||||
|
|
||||||
bool bind_ggml_tensor(ggml_tensor *tensor, bool is_input) {
|
bool bind_ggml_tensor(ggml_tensor *tensor, bool is_input, int prev_max_rank) {
|
||||||
if (_tensor) {
|
if (_tensor) {
|
||||||
if (_tensor != tensor) {
|
if (_tensor != tensor) {
|
||||||
QNN_LOG_WARN("tensor %s has been bound to another ggml tensor %s", _tensor_name.c_str(),
|
QNN_LOG_WARN("tensor %s has been bound to another ggml tensor %s", _tensor_name.c_str(),
|
||||||
|
@ -41,7 +41,7 @@ public:
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
update_params_from_ggml_tensor(tensor);
|
update_params_from_ggml_tensor(tensor, prev_max_rank);
|
||||||
Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
|
Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
|
||||||
QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
|
QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
|
||||||
QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
|
QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
|
||||||
|
@ -54,8 +54,10 @@ public:
|
||||||
QNN_LOG_WARN("create graph tensor failed, tensor %s, error: %d\n", _tensor_name.c_str(), error);
|
QNN_LOG_WARN("create graph tensor failed, tensor %s, error: %d\n", _tensor_name.c_str(), error);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
QNN_TENSOR_SET_ID(_qnn_tensor, QNN_TENSOR_GET_ID(qnn_tensor));
|
QNN_TENSOR_SET_ID(_qnn_tensor, QNN_TENSOR_GET_ID(qnn_tensor));
|
||||||
QNN_LOG_DEBUG("create graph tensor %s, id: %d", _tensor_name.c_str(), QNN_TENSOR_GET_ID(qnn_tensor));
|
QNN_LOG_DEBUG("create graph tensor %s, id: %d, rank: %d", _tensor_name.c_str(),
|
||||||
|
QNN_TENSOR_GET_ID(qnn_tensor), QNN_TENSOR_GET_RANK(qnn_tensor));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (should_use_mem_handle()) {
|
if (should_use_mem_handle()) {
|
||||||
|
@ -166,14 +168,15 @@ private:
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void update_params_from_ggml_tensor(ggml_tensor *tensor) {
|
void update_params_from_ggml_tensor(ggml_tensor *tensor, int prev_max_rank) {
|
||||||
_dimensions[0] = (uint32_t)tensor->ne[0];
|
_dimensions[0] = (uint32_t)tensor->ne[0];
|
||||||
_dimensions[1] = (uint32_t)tensor->ne[1];
|
_dimensions[1] = (uint32_t)tensor->ne[1];
|
||||||
_dimensions[2] = (uint32_t)tensor->ne[2];
|
_dimensions[2] = (uint32_t)tensor->ne[2];
|
||||||
_dimensions[3] = (uint32_t)tensor->ne[3];
|
_dimensions[3] = (uint32_t)tensor->ne[3];
|
||||||
QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type));
|
QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type));
|
||||||
// TODO: set the quantizeParams base on the tensor type
|
// TODO: set the quantizeParams base on the tensor type
|
||||||
QNN_TENSOR_SET_RANK(_qnn_tensor, (uint32_t)ggml_n_dims(tensor));
|
|
||||||
|
QNN_TENSOR_SET_RANK(_qnn_tensor, (uint32_t)std::max(prev_max_rank, ggml_n_dims(tensor)));
|
||||||
|
|
||||||
QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_RAW);
|
QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_RAW);
|
||||||
Qnn_ClientBuffer_t client_buf = {};
|
Qnn_ClientBuffer_t client_buf = {};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue