sycl::queue can directly use as shared_ptr

Signed-off-by: Chen Xi <xi2chen@intel.com>
This commit is contained in:
Chen Xi 2024-07-18 08:07:05 +00:00
parent bd71cdac0f
commit 6160a76efb

View file

@ -687,93 +687,89 @@ namespace dpct
init_queues(); init_queues();
} }
sycl::queue &in_order_queue() { return *_q_in_order; } sycl::queue &in_order_queue() { return _q_in_order; }
sycl::queue &out_of_order_queue() { return *_q_out_of_order; } sycl::queue &out_of_order_queue() { return _q_out_of_order; }
sycl::queue &default_queue() { return in_order_queue(); } sycl::queue &default_queue() { return in_order_queue(); }
void queues_wait_and_throw() { void queues_wait_and_throw() {
std::unique_lock<mutex_type> lock(m_mutex); std::unique_lock<mutex_type> lock(m_mutex);
lock.unlock(); lock.unlock();
for (const auto &q : _queues) { for (auto &q : _queues) {
q->wait_and_throw(); q.wait_and_throw();
} }
// Guard the destruct of current_queues to make sure the ref count is // Guard the destruct of current_queues to make sure the ref count is
// safe. // safe.
lock.lock(); lock.lock();
} }
sycl::queue *create_queue(bool enable_exception_handler = false) { sycl::queue create_queue(bool enable_exception_handler = false) {
return create_in_order_queue(enable_exception_handler); return create_in_order_queue(enable_exception_handler);
} }
sycl::queue *create_queue(sycl::device device, sycl::queue create_queue(sycl::device device,
bool enable_exception_handler = false) { bool enable_exception_handler = false) {
return create_in_order_queue(device, enable_exception_handler); return create_in_order_queue(device, enable_exception_handler);
} }
sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(enable_exception_handler, return create_queue_impl(enable_exception_handler,
sycl::property::queue::in_order()); sycl::property::queue::in_order());
} }
sycl::queue *create_in_order_queue(sycl::device device, sycl::queue create_in_order_queue(sycl::device device,
bool enable_exception_handler = false) { bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(device, enable_exception_handler, return create_queue_impl(device, enable_exception_handler,
sycl::property::queue::in_order()); sycl::property::queue::in_order());
} }
sycl::queue *create_out_of_order_queue( sycl::queue create_out_of_order_queue(
bool enable_exception_handler = false) { bool enable_exception_handler = false) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return create_queue_impl(enable_exception_handler); return create_queue_impl(enable_exception_handler);
} }
void destroy_queue(sycl::queue *&queue) { void destroy_queue(sycl::queue queue) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
_queues.erase(std::remove_if(_queues.begin(), _queues.end(), _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
[=](const std::shared_ptr<sycl::queue> &q) -> bool [=](const sycl::queue &q) -> bool
{ {
return q.get() == queue; return q == queue;
}), }),
_queues.end()); _queues.end());
queue = nullptr;
} }
void set_saved_queue(sycl::queue *q) { void set_saved_queue(sycl::queue q) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
_saved_queue = q; _saved_queue = q;
} }
sycl::queue *get_saved_queue() const { sycl::queue get_saved_queue() const {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
return _saved_queue; return _saved_queue;
} }
private: private:
void clear_queues() { void clear_queues() { _queues.clear(); }
_queues.clear();
_q_in_order = _q_out_of_order = _saved_queue = nullptr;
}
void init_queues() { void init_queues() {
_q_in_order = _q_in_order =
create_queue_impl(true, sycl::property::queue::in_order()); create_queue_impl(true, sycl::property::queue::in_order());
_q_out_of_order = create_queue_impl(true); _q_out_of_order = create_queue_impl(true);
_saved_queue = &default_queue(); _saved_queue = default_queue();
} }
/// Caller should acquire resource \p m_mutex before calling this /// Caller should acquire resource \p m_mutex before calling this
/// function. /// function.
template <class... Properties> template <class... Properties>
sycl::queue *create_queue_impl(bool enable_exception_handler, sycl::queue create_queue_impl(bool enable_exception_handler,
Properties... properties) { Properties... properties) {
sycl::async_handler eh = {}; sycl::async_handler eh = {};
if (enable_exception_handler) { if (enable_exception_handler) {
eh = exception_handler; eh = exception_handler;
} }
_queues.push_back(std::make_shared<sycl::queue>( _queues.push_back(sycl::queue(
*this, eh, *this, eh,
sycl::property_list( sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
@ -781,18 +777,18 @@ namespace dpct
#endif #endif
properties...))); properties...)));
return _queues.back().get(); return _queues.back();
} }
template <class... Properties> template <class... Properties>
sycl::queue *create_queue_impl(sycl::device device, sycl::queue create_queue_impl(sycl::device device,
bool enable_exception_handler, bool enable_exception_handler,
Properties... properties) { Properties... properties) {
sycl::async_handler eh = {}; sycl::async_handler eh = {};
if (enable_exception_handler) { if (enable_exception_handler) {
eh = exception_handler; eh = exception_handler;
} }
_queues.push_back(std::make_shared<sycl::queue>( _queues.push_back(sycl::queue(
device, eh, device, eh,
sycl::property_list( sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
@ -800,15 +796,15 @@ namespace dpct
#endif #endif
properties...))); properties...)));
return _queues.back().get(); return _queues.back();
} }
void get_version(int &major, int &minor) const { void get_version(int &major, int &minor) const {
detail::get_version(*this, major, minor); detail::get_version(*this, major, minor);
} }
sycl::queue *_q_in_order, *_q_out_of_order; sycl::queue _q_in_order, _q_out_of_order;
sycl::queue *_saved_queue; sycl::queue _saved_queue;
std::vector<std::shared_ptr<sycl::queue>> _queues; std::vector<sycl::queue> _queues;
mutable mutex_type m_mutex; mutable mutex_type m_mutex;
}; };