diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 40ac8efc3..0f18cff32 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -687,93 +687,89 @@ namespace dpct init_queues(); } - sycl::queue &in_order_queue() { return *_q_in_order; } + sycl::queue &in_order_queue() { return _q_in_order; } - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + sycl::queue &out_of_order_queue() { return _q_out_of_order; } sycl::queue &default_queue() { return in_order_queue(); } void queues_wait_and_throw() { std::unique_lock lock(m_mutex); lock.unlock(); - for (const auto &q : _queues) { - q->wait_and_throw(); + for (auto &q : _queues) { + q.wait_and_throw(); } // Guard the destruct of current_queues to make sure the ref count is // safe. lock.lock(); } - sycl::queue *create_queue(bool enable_exception_handler = false) { + sycl::queue create_queue(bool enable_exception_handler = false) { return create_in_order_queue(enable_exception_handler); } - sycl::queue *create_queue(sycl::device device, + sycl::queue create_queue(sycl::device device, bool enable_exception_handler = false) { return create_in_order_queue(device, enable_exception_handler); } - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { + sycl::queue create_in_order_queue(bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(enable_exception_handler, sycl::property::queue::in_order()); } - sycl::queue *create_in_order_queue(sycl::device device, + sycl::queue create_in_order_queue(sycl::device device, bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(device, enable_exception_handler, sycl::property::queue::in_order()); } - sycl::queue *create_out_of_order_queue( + sycl::queue create_out_of_order_queue( bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(enable_exception_handler); } - void destroy_queue(sycl::queue *&queue) { + void destroy_queue(sycl::queue queue) { std::lock_guard lock(m_mutex); _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr &q) -> bool + [=](const sycl::queue &q) -> bool { - return q.get() == queue; + return q == queue; }), _queues.end()); - queue = nullptr; } - void set_saved_queue(sycl::queue *q) { + void set_saved_queue(sycl::queue q) { std::lock_guard lock(m_mutex); _saved_queue = q; } - sycl::queue *get_saved_queue() const { + sycl::queue get_saved_queue() const { std::lock_guard lock(m_mutex); return _saved_queue; } private: - void clear_queues() { - _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; - } + void clear_queues() { _queues.clear(); } void init_queues() { _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); + _saved_queue = default_queue(); } /// Caller should acquire resource \p m_mutex before calling this /// function. template - sycl::queue *create_queue_impl(bool enable_exception_handler, + sycl::queue create_queue_impl(bool enable_exception_handler, Properties... properties) { sycl::async_handler eh = {}; if (enable_exception_handler) { eh = exception_handler; } - _queues.push_back(std::make_shared( + _queues.push_back(sycl::queue( *this, eh, sycl::property_list( #ifdef DPCT_PROFILING_ENABLED @@ -781,18 +777,18 @@ namespace dpct #endif properties...))); - return _queues.back().get(); + return _queues.back(); } template - sycl::queue *create_queue_impl(sycl::device device, + sycl::queue create_queue_impl(sycl::device device, bool enable_exception_handler, Properties... properties) { sycl::async_handler eh = {}; if (enable_exception_handler) { eh = exception_handler; } - _queues.push_back(std::make_shared( + _queues.push_back(sycl::queue( device, eh, sycl::property_list( #ifdef DPCT_PROFILING_ENABLED @@ -800,15 +796,15 @@ namespace dpct #endif properties...))); - return _queues.back().get(); + return _queues.back(); } void get_version(int &major, int &minor) const { detail::get_version(*this, major, minor); } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - std::vector> _queues; + sycl::queue _q_in_order, _q_out_of_order; + sycl::queue _saved_queue; + std::vector _queues; mutable mutex_type m_mutex; };