merge register_rpc_mem into alloc_rpc_mem

This commit is contained in:
hongruichen 2024-07-10 19:40:02 +08:00
parent e97d3a6c48
commit 3feb574bf0

View file

@ -2,6 +2,7 @@
#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
@ -9,7 +10,6 @@
#include "QnnTensor.h"
#include "System/QnnSystemInterface.h"
#include "backend.hpp"
#include "graph.hpp"
#include "logger.hpp"
#include "qnn.hpp"
@ -88,12 +88,6 @@ public:
return false;
}
auto tensor_type = QNN_TENSOR_GET_TYPE(_qnn_tensor);
if (!register_rpc_mem(_qnn_rpc_buffer)) {
QNN_LOG_WARN("commit rpc mem failure\n");
return false;
}
QNN_LOG_DEBUG("tensor %s, use mem handle %p", _tensor_name.c_str(), QNN_TENSOR_GET_MEM_HANDLE(_qnn_tensor));
} else {
QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_RAW);
@ -176,26 +170,18 @@ private:
}
QNN_LOG_INFO("tensor %s: alloc rpcmem(%p) successfully\n", _tensor_name.c_str(), qnn_rpc_buffer);
return qnn_rpc_buffer;
}
bool register_rpc_mem(uint8_t *qnn_rpc_buffer) {
if (_qnn_instance->is_rpcmem_registered(QNN_TENSOR_GET_MEM_HANDLE(_qnn_tensor))) {
QNN_LOG_INFO("tensor %s: rpcmem(%p) already registered\n", _tensor_name.c_str(), qnn_rpc_buffer);
return true;
}
auto error = _qnn_instance->register_rpcmem(qnn_rpc_buffer, &_qnn_tensor);
if (error != QNN_SUCCESS) {
QNN_LOG_WARN("register rpc mem failure, %d\n", (int)error);
QNN_LOG_DEBUG("tensor name %s", _tensor_name.c_str());
return false;
_qnn_instance->free_rpcmem(qnn_rpc_buffer);
return nullptr;
}
// The mem handle will be set at qnn_instance::register_rpcmem
QNN_TENSOR_SET_MEM_TYPE(_qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
QNN_LOG_INFO("tensor %s: register rpcmem(%p) successfully\n", _tensor_name.c_str(), qnn_rpc_buffer);
return true;
return qnn_rpc_buffer;
}
bool should_use_mem_handle() const { return _device == QNN_BACKEND_NPU; }