feat: update tensor name when binding to graph

hongruichen 2024-07-20 16:57:51 +08:00
parent 5f3b1ae3b0
commit b173c4e061
2 changed files with 21 additions and 13 deletions
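The diff below moves the fallback-naming logic out of the ggml_qnn_tensor constructor into a new update_tensor_name() helper and calls it again when the tensor is bound to a graph, so a tensor that was unnamed at construction picks up its real ggml name once one has been assigned. As a rough standalone sketch of that naming policy only (the SourceTensor struct and refresh_name() below are illustrative stand-ins, not the ggml/QNN types touched by this commit):

```cpp
// Minimal sketch of the naming policy: refresh the cached name from the
// source tensor on every (re)bind, and only generate an "unnamed_N" fallback
// when the source still has no name and nothing was cached before.
// SourceTensor and refresh_name are hypothetical stand-ins for illustration.
#include <atomic>
#include <cstdio>
#include <string>

struct SourceTensor {      // stand-in for ggml_tensor
    std::string name;      // may be empty for intermediate tensors
};

static std::string refresh_name(const SourceTensor &src, std::string cached) {
    if (src.name.empty()) {
        // keep any previously cached name; generate one only the first time
        if (cached.empty()) {
            static std::atomic_uint32_t unnamed_tensor_count{0};
            char buffer[64] = {};
            std::snprintf(buffer, sizeof(buffer), "unnamed_%d", (int)(unnamed_tensor_count++));
            cached = buffer;
        }
    } else {
        // the source tensor was named after construction: pick the name up
        cached = src.name;
    }
    return cached;
}

int main() {
    SourceTensor t{""};
    std::string name = refresh_name(t, "");  // -> "unnamed_0"
    std::printf("%s\n", name.c_str());

    t.name = "Qcur-0";
    name = refresh_name(t, name);            // -> "Qcur-0" once a real name exists
    std::printf("%s\n", name.c_str());
    return 0;
}
```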

View file

@@ -180,6 +180,7 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
    auto it = graph_cache.find(graph_key);
    graph_t *graph_ptr = nullptr;
    if (it != graph_cache.end()) {
        QNN_LOG_DEBUG("found graph %s in cache\n", graph_key.c_str());
        graph_ptr = it->second.get();
    } else {
        auto graph =

View file

@@ -29,14 +29,7 @@ public:
    explicit ggml_qnn_tensor(ggml_tensor *tensor, QNNBackend device, std::shared_ptr<qnn_instance> qnn_instance) :
        _tensor(tensor), _device(device), _qnn_instance(qnn_instance) {
        _tensor_name = ggml_get_name(tensor);
        if (_tensor_name.empty()) {
            static std::atomic_uint32_t unnamed_tensor_count = 0;
            char buffer[GGML_MAX_NAME] = {};
            snprintf(buffer, sizeof(buffer), "unnamed_%d", (int)(unnamed_tensor_count++));
            _tensor_name = buffer;
        }
        update_tensor_name();
        QNN_TENSOR_SET_NAME(_qnn_tensor, _tensor_name.c_str());
        _dimensions[0] = (uint32_t)tensor->ne[0];
        _dimensions[1] = (uint32_t)tensor->ne[1];
@@ -79,6 +72,7 @@ public:
        Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
        QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
        QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
        update_tensor_name();
        Qnn_Tensor_t tensor = _qnn_tensor;
        if (!graph.create_graph_tensor(tensor)) {
            QNN_LOG_WARN("create graph tensor failed, tensor %s", _tensor_name.c_str());
@@ -116,15 +110,14 @@ public:
        auto tensor_type = QNN_TENSOR_GET_TYPE(_qnn_tensor);
        if (tensor_type != QNN_TENSOR_TYPE_APP_WRITE && tensor_type != QNN_TENSOR_TYPE_APP_READWRITE) {
            QNN_LOG_WARN("tensor %s not writable", _tensor_name.c_str());
            return false;
            QNN_LOG_WARN("tensor %s type(%d) not WRITE", _tensor_name.c_str(), (int)tensor_type);
        }

        if (should_use_mem_handle()) {
            if (_qnn_rpc_buffer) {
                memcpy(_qnn_rpc_buffer, _tensor->data, ggml_nbytes(_tensor));
            } else {
                QNN_LOG_WARN("can't find rpcmem from qnn mem handle\n");
                QNN_LOG_WARN("tensor %s: can't find rpcmem from qnn mem handle\n", _tensor_name.c_str());
                return false;
            }
        }
@@ -142,8 +135,7 @@ public:
        auto tensor_type = QNN_TENSOR_GET_TYPE(_qnn_tensor);
        if (tensor_type != QNN_TENSOR_TYPE_APP_READ && tensor_type != QNN_TENSOR_TYPE_APP_READWRITE) {
            QNN_LOG_WARN("tensor %s not readable", _tensor_name.c_str());
            return false;
            QNN_LOG_WARN("tensor %s type(%d) not READ", _tensor_name.c_str(), (int)tensor_type);
        }

        if (should_use_mem_handle()) {
@@ -190,6 +182,21 @@ private:
    bool should_use_mem_handle() const { return _device == QNN_BACKEND_NPU; }

    void update_tensor_name() {
        auto *tensor_name = ggml_get_name(_tensor);
        if (!strnlen(tensor_name, GGML_MAX_NAME)) {
            if (_tensor_name.empty()) {
                static std::atomic_uint32_t unnamed_tensor_count = 0;
                char buffer[GGML_MAX_NAME] = {};
                snprintf(buffer, sizeof(buffer), "unnamed_%d", (int)(unnamed_tensor_count++));
                _tensor_name = buffer;
            }
        } else {
            QNN_LOG_DEBUG("tensor name changed: %s -> %s", _tensor_name.c_str(), tensor_name);
            _tensor_name = tensor_name;
        }
    }

    const ggml_tensor *_tensor;
    QNNBackend _device;
    std::shared_ptr<qnn_instance> _qnn_instance;