feat: update tensor name when bind to graph
This commit is contained in:
parent
5f3b1ae3b0
commit
b173c4e061
2 changed files with 21 additions and 13 deletions
|
@ -180,6 +180,7 @@ qnn::ggml_qnn_graph<_InputSize, _OutputSize> *get_qnn_graph_from_cache(
|
|||
auto it = graph_cache.find(graph_key);
|
||||
graph_t *graph_ptr = nullptr;
|
||||
if (it != graph_cache.end()) {
|
||||
QNN_LOG_DEBUG("found graph %s in cache\n", graph_key.c_str());
|
||||
graph_ptr = it->second.get();
|
||||
} else {
|
||||
auto graph =
|
||||
|
|
|
@ -29,14 +29,7 @@ public:
|
|||
|
||||
explicit ggml_qnn_tensor(ggml_tensor *tensor, QNNBackend device, std::shared_ptr<qnn_instance> qnn_instance) :
|
||||
_tensor(tensor), _device(device), _qnn_instance(qnn_instance) {
|
||||
_tensor_name = ggml_get_name(tensor);
|
||||
if (_tensor_name.empty()) {
|
||||
static std::atomic_uint32_t unnamed_tensor_count = 0;
|
||||
char buffer[GGML_MAX_NAME] = {};
|
||||
snprintf(buffer, sizeof(buffer), "unnamed_%d", (int)(unnamed_tensor_count++));
|
||||
_tensor_name = buffer;
|
||||
}
|
||||
|
||||
update_tensor_name();
|
||||
QNN_TENSOR_SET_NAME(_qnn_tensor, _tensor_name.c_str());
|
||||
_dimensions[0] = (uint32_t)tensor->ne[0];
|
||||
_dimensions[1] = (uint32_t)tensor->ne[1];
|
||||
|
@ -79,6 +72,7 @@ public:
|
|||
Qnn_TensorType_t new_tensor_type = is_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_APP_READ;
|
||||
QNN_TENSOR_SET_TYPE(_qnn_tensor, new_tensor_type);
|
||||
QNN_LOG_INFO("tensor %s changed to type %d", _tensor_name.c_str(), new_tensor_type);
|
||||
update_tensor_name();
|
||||
Qnn_Tensor_t tensor = _qnn_tensor;
|
||||
if (!graph.create_graph_tensor(tensor)) {
|
||||
QNN_LOG_WARN("create graph tensor failed, tensor %s", _tensor_name.c_str());
|
||||
|
@ -116,15 +110,14 @@ public:
|
|||
|
||||
auto tensor_type = QNN_TENSOR_GET_TYPE(_qnn_tensor);
|
||||
if (tensor_type != QNN_TENSOR_TYPE_APP_WRITE && tensor_type != QNN_TENSOR_TYPE_APP_READWRITE) {
|
||||
QNN_LOG_WARN("tensor %s not writable", _tensor_name.c_str());
|
||||
return false;
|
||||
QNN_LOG_WARN("tensor %s type(%d) not WRITE", _tensor_name.c_str(), (int)tensor_type);
|
||||
}
|
||||
|
||||
if (should_use_mem_handle()) {
|
||||
if (_qnn_rpc_buffer) {
|
||||
memcpy(_qnn_rpc_buffer, _tensor->data, ggml_nbytes(_tensor));
|
||||
} else {
|
||||
QNN_LOG_WARN("can't find rpcmem from qnn mem handle\n");
|
||||
QNN_LOG_WARN("tensor %s: can't find rpcmem from qnn mem handle\n", _tensor_name.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -142,8 +135,7 @@ public:
|
|||
|
||||
auto tensor_type = QNN_TENSOR_GET_TYPE(_qnn_tensor);
|
||||
if (tensor_type != QNN_TENSOR_TYPE_APP_READ && tensor_type != QNN_TENSOR_TYPE_APP_READWRITE) {
|
||||
QNN_LOG_WARN("tensor %s not readable", _tensor_name.c_str());
|
||||
return false;
|
||||
QNN_LOG_WARN("tensor %s type(%d) not READ", _tensor_name.c_str(), (int)tensor_type);
|
||||
}
|
||||
|
||||
if (should_use_mem_handle()) {
|
||||
|
@ -190,6 +182,21 @@ private:
|
|||
|
||||
bool should_use_mem_handle() const { return _device == QNN_BACKEND_NPU; }
|
||||
|
||||
void update_tensor_name() {
|
||||
auto *tensor_name = ggml_get_name(_tensor);
|
||||
if (!strnlen(tensor_name, GGML_MAX_NAME)) {
|
||||
if (_tensor_name.empty()) {
|
||||
static std::atomic_uint32_t unnamed_tensor_count = 0;
|
||||
char buffer[GGML_MAX_NAME] = {};
|
||||
snprintf(buffer, sizeof(buffer), "unnamed_%d", (int)(unnamed_tensor_count++));
|
||||
_tensor_name = buffer;
|
||||
}
|
||||
} else {
|
||||
QNN_LOG_DEBUG("tensor name changed: %s -> %s", _tensor_name.c_str(), tensor_name);
|
||||
_tensor_name = tensor_name;
|
||||
}
|
||||
}
|
||||
|
||||
const ggml_tensor *_tensor;
|
||||
QNNBackend _device;
|
||||
std::shared_ptr<qnn_instance> _qnn_instance;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue