diff --git a/ggml/src/ggml-qnn/qnn-lib.hpp b/ggml/src/ggml-qnn/qnn-lib.hpp index aa142c74a..da986e2e4 100644 --- a/ggml/src/ggml-qnn/qnn-lib.hpp +++ b/ggml/src/ggml-qnn/qnn-lib.hpp @@ -591,8 +591,6 @@ public: size_t get_rpcmem_capacity() { return _rpcmem_capacity; } - bool is_rpcmem_registered(Qnn_MemHandle_t handle) { return _qnn_mem_set.count(handle) != 0U; } - void *alloc_rpcmem(size_t bytes, size_t alignment) { if (!_rpcmem_initialized) { QNN_LOG_WARN("rpc memory not initialized\n"); @@ -619,7 +617,7 @@ public: void free_rpcmem(void *buf) { if (!_rpcmem_initialized) { QNN_LOG_WARN("rpc memory not initialized\n"); - } else if (0 == _rpcmem_store_map.count(buf)) { + } else if (_rpcmem_store_map.count(buf) == 0) { QNN_LOG_WARN("no allocated tensor\n"); } else { _pfn_rpc_mem_free(_rpcmem_store_map[buf]); @@ -638,18 +636,6 @@ public: return mem_fd; } - void *get_rpcmem_from_memhandle(Qnn_MemHandle_t mem_handle) { - for (std::unordered_map::iterator it = _qnn_mem_set.begin(); it != _qnn_mem_set.end(); - it++) { - Qnn_MemHandle_t mem_handle = it->second; - if (it->second == mem_handle) { - return it->first; - } - } - QNN_LOG_WARN("can't find rpcmem from qnn mem handle %p", mem_handle); - return nullptr; - } - Qnn_MemHandle_t register_rpcmem(void *p_data, uint32_t rank, uint32_t *dimensions, Qnn_DataType_t data_type) { if (!p_data) { QNN_LOG_WARN("invalid param\n"); @@ -661,9 +647,9 @@ public: return nullptr; } - if (is_rpcmem_allocated(p_data)) { - QNN_LOG_WARN("rpc memory already allocated\n"); - return nullptr; + if (is_rpcmem_registered(p_data)) { + QNN_LOG_WARN("rpc memory already registered\n"); + return _qnn_rpc_buffer_to_handles[p_data]; } auto mem_fd = rpcmem_to_fd(p_data); @@ -683,8 +669,7 @@ public: return nullptr; } - _qnn_mem_set.insert((std::pair(p_data, handle))); - + _qnn_rpc_buffer_to_handles.insert({ p_data, handle }); QNN_LOG_INFO("successfully register shared memory handler: %p\n", handle); return handle; } @@ -695,14 +680,18 @@ public: QNN_LOG_WARN("failed to unregister shared memory, error %d\n", QNN_GET_ERROR_CODE(error)); } - auto it = std::find_if(_qnn_mem_set.begin(), _qnn_mem_set.end(), + auto it = std::find_if(_qnn_rpc_buffer_to_handles.begin(), _qnn_rpc_buffer_to_handles.end(), [mem_handle](const auto &kv) { return kv.second == mem_handle; }); - if (it != _qnn_mem_set.end()) { - _qnn_mem_set.erase(it); + if (it == _qnn_rpc_buffer_to_handles.end()) { + QNN_LOG_WARN("failed to find shared memory handler: %p\n", mem_handle); + return; } + + _qnn_rpc_buffer_to_handles.erase(it); } - bool is_rpcmem_allocated(void *buf) { return _qnn_mem_set.count(buf) != 0U; } + bool is_rpcmem_allocated(void *buf) { return _rpcmem_store_map.count(buf) != 0; } + bool is_rpcmem_registered(void *buf) { return _qnn_rpc_buffer_to_handles.count(buf) != 0U; } const qnn::qcom_socinfo &get_soc_info() { return _soc_info; } @@ -892,7 +881,7 @@ private: QnnHtpDevice_PerfInfrastructure_t *_qnn_htp_perfinfra = nullptr; uint32_t _qnn_power_configid = 1; - std::unordered_map _qnn_mem_set; + std::unordered_map _qnn_rpc_buffer_to_handles; std::mutex _init_mutex; std::unordered_map _loaded_lib_handle;