fix: try fix the tensor rank of mul mat

This commit is contained in:
hongruichen 2024-07-31 22:44:21 +08:00
parent 6cc7432b37
commit 47f6e02eda
4 changed files with 26 additions and 13 deletions

View file

@ -133,6 +133,7 @@ qnn::ggml_op_constructor_t generate_common_op_constructor(const std::string &op_
scalar.dataType = QNN_DATATYPE_BOOL_8;
scalar.bool8Value = true;
config->add_scalar_param(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar);
QNN_LOG_DEBUG("add scalar param %s\n", QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0);
return config;
};
}

View file

@ -91,6 +91,14 @@ public:
return false;
}
// get the max tensor rank
for (auto tensor : tensor_inputs) {
_tensor_rank = std::max(_tensor_rank, ggml_n_dims(tensor));
}
for (auto tensor : tensor_outputs) {
_tensor_rank = std::max(_tensor_rank, ggml_n_dims(tensor));
}
QNN_LOG_DEBUG("graph name %s, build_graph start", _graph_name.c_str());
_tensor_inputs.resize(tensor_inputs.size());
for (size_t i = 0; i < tensor_inputs.size(); i++) {
@ -99,7 +107,7 @@ public:
auto qnn_tensor =
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
auto *ggml_tensor = tensor_inputs[i];
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, true)) {
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, true, _tensor_rank)) {
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
return false;
}
@ -114,7 +122,7 @@ public:
auto qnn_tensor =
std::make_shared<ggml_qnn_tensor>(std::string(buffer), _device, _graph_handle, _qnn_instance);
auto *ggml_tensor = tensor_outputs[i];
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, false)) {
if (!qnn_tensor->bind_ggml_tensor(ggml_tensor, false, _tensor_rank)) {
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
return false;
}
@ -156,7 +164,7 @@ public:
GGML_ASSERT(tensor_outputs.size() == _tensor_outputs.size());
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto *ggml_tensor = tensor_inputs[i];
if (!_tensor_inputs[i]->bind_ggml_tensor(ggml_tensor, true)) {
if (!_tensor_inputs[i]->bind_ggml_tensor(ggml_tensor, true, _tensor_rank)) {
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
return false;
}
@ -164,7 +172,7 @@ public:
for (size_t i = 0; i < tensor_outputs.size(); i++) {
auto *ggml_tensor = tensor_outputs[i];
if (!_tensor_outputs[i]->bind_ggml_tensor(ggml_tensor, false)) {
if (!_tensor_outputs[i]->bind_ggml_tensor(ggml_tensor, false, _tensor_rank)) {
QNN_LOG_ERROR("bind tensor %s failed\n", ggml_get_name(ggml_tensor));
return false;
}
@ -216,6 +224,7 @@ private:
std::vector<std::shared_ptr<ggml_qnn_tensor>> _tensor_outputs;
std::unique_ptr<ggml_qnn_op_config> _op_config;
std::vector<Qnn_Param_t> _param_types;
int _tensor_rank = 0;
DISABLE_COPY(ggml_qnn_graph);
DISABLE_MOVE(ggml_qnn_graph);

View file

@ -36,7 +36,7 @@ public:
param.paramType = QNN_PARAMTYPE_SCALAR;
param.name = _param_names.back().c_str();
param.scalarParam = scalar;
_param_types.push_back(param);
_parameters.push_back(param);
}
std::vector<Qnn_Tensor_t> &get_qnn_input_tensors() { return _qnn_tensor_inputs; }
@ -49,8 +49,8 @@ public:
op_config.name = _name.c_str();
op_config.packageName = _package_name.c_str();
op_config.typeName = _op_type.c_str();
op_config.numOfParams = (uint32_t)_param_types.size();
op_config.params = _param_types.data();
op_config.numOfParams = (uint32_t)_parameters.size();
op_config.params = _parameters.data();
op_config.numOfInputs = (uint32_t)_qnn_tensor_inputs.size();
op_config.inputTensors = _qnn_tensor_inputs.data();
op_config.numOfOutputs = (uint32_t)_qnn_tensor_outputs.size();
@ -64,7 +64,7 @@ private:
std::string _op_type;
std::vector<Qnn_Tensor_t> _qnn_tensor_inputs;
std::vector<Qnn_Tensor_t> _qnn_tensor_outputs;
std::vector<Qnn_Param_t> _param_types;
std::vector<Qnn_Param_t> _parameters;
std::vector<std::string> _param_names;
DISABLE_COPY(ggml_qnn_op_config);

View file

@ -29,7 +29,7 @@ public:
~ggml_qnn_tensor() { _qnn_rpc_buffer.reset(); }
bool bind_ggml_tensor(ggml_tensor *tensor, bool is_input) {
bool bind_ggml_tensor(ggml_tensor *tensor, bool is_input, int prev_max_rank) {
if (_tensor) {
if (_tensor != tensor) {
QNN_LOG_WARN("tensor %s has been bound to another ggml tensor %s", _tensor_name.c_str(),
@ -41,7 +41,7 @@ public:
return true;
}
update_params_from_ggml_tensor(tensor);
update_params_from_ggml_tensor(tensor, prev_max_rank);
Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
@ -54,8 +54,10 @@ public:
QNN_LOG_WARN("create graph tensor failed, tensor %s, error: %d\n", _tensor_name.c_str(), error);
return false;
}
QNN_TENSOR_SET_ID(_qnn_tensor, QNN_TENSOR_GET_ID(qnn_tensor));
QNN_LOG_DEBUG("create graph tensor %s, id: %d", _tensor_name.c_str(), QNN_TENSOR_GET_ID(qnn_tensor));
QNN_LOG_DEBUG("create graph tensor %s, id: %d, rank: %d", _tensor_name.c_str(),
QNN_TENSOR_GET_ID(qnn_tensor), QNN_TENSOR_GET_RANK(qnn_tensor));
}
if (should_use_mem_handle()) {
@ -166,14 +168,15 @@ private:
return true;
}
void update_params_from_ggml_tensor(ggml_tensor *tensor) {
void update_params_from_ggml_tensor(ggml_tensor *tensor, int prev_max_rank) {
_dimensions[0] = (uint32_t)tensor->ne[0];
_dimensions[1] = (uint32_t)tensor->ne[1];
_dimensions[2] = (uint32_t)tensor->ne[2];
_dimensions[3] = (uint32_t)tensor->ne[3];
QNN_TENSOR_SET_DATA_TYPE(_qnn_tensor, device_datatype_from_ggml_datatype(tensor->type));
// TODO: set the quantizeParams base on the tensor type
QNN_TENSOR_SET_RANK(_qnn_tensor, (uint32_t)ggml_n_dims(tensor));
QNN_TENSOR_SET_RANK(_qnn_tensor, (uint32_t)std::max(prev_max_rank, ggml_n_dims(tensor)));
QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_RAW);
Qnn_ClientBuffer_t client_buf = {};