From a47f5ec42e3e78b88be549d3b0146f690fbd1758 Mon Sep 17 00:00:00 2001 From: jianyuzh Date: Sat, 13 Jan 2024 20:33:42 +0800 Subject: [PATCH] summary dpct definition in one header file to replace folder:dpct --- dpct.hpp | 2831 ++++++++++++++++ dpct/atomic.hpp | 842 ----- dpct/blas_utils.hpp | 1792 ---------- dpct/ccl_utils.hpp | 286 -- dpct/device.hpp | 781 ----- dpct/dnnl_utils.hpp | 4921 ---------------------------- dpct/dpct.hpp | 62 - dpct/dpl_extras/algorithm.h | 2419 -------------- dpct/dpl_extras/dpcpp_extensions.h | 747 ----- dpct/dpl_extras/functional.h | 453 --- dpct/dpl_extras/iterators.h | 347 -- dpct/dpl_extras/memory.h | 1024 ------ dpct/dpl_extras/numeric.h | 32 - dpct/dpl_extras/vector.h | 752 ----- dpct/dpl_utils.hpp | 26 - dpct/fft_utils.hpp | 1376 -------- dpct/image.hpp | 901 ----- dpct/kernel.hpp | 459 --- dpct/lapack_utils.hpp | 1953 ----------- dpct/lib_common_utils.hpp | 174 - dpct/math.hpp | 1814 ---------- dpct/memory.hpp | 1497 --------- dpct/rng_utils.hpp | 535 --- dpct/sparse_utils.hpp | 1385 -------- dpct/util.hpp | 1070 ------ ggml-sycl.cpp | 7 +- run.sh | 2 +- 27 files changed, 2836 insertions(+), 25652 deletions(-) create mode 100644 dpct.hpp delete mode 100644 dpct/atomic.hpp delete mode 100644 dpct/blas_utils.hpp delete mode 100644 dpct/ccl_utils.hpp delete mode 100644 dpct/device.hpp delete mode 100644 dpct/dnnl_utils.hpp delete mode 100644 dpct/dpct.hpp delete mode 100644 dpct/dpl_extras/algorithm.h delete mode 100644 dpct/dpl_extras/dpcpp_extensions.h delete mode 100644 dpct/dpl_extras/functional.h delete mode 100644 dpct/dpl_extras/iterators.h delete mode 100644 dpct/dpl_extras/memory.h delete mode 100644 dpct/dpl_extras/numeric.h delete mode 100644 dpct/dpl_extras/vector.h delete mode 100644 dpct/dpl_utils.hpp delete mode 100644 dpct/fft_utils.hpp delete mode 100644 dpct/image.hpp delete mode 100644 dpct/kernel.hpp delete mode 100644 dpct/lapack_utils.hpp delete mode 100644 dpct/lib_common_utils.hpp delete mode 100644 dpct/math.hpp delete mode 100644 dpct/memory.hpp delete mode 100644 dpct/rng_utils.hpp delete mode 100644 dpct/sparse_utils.hpp delete mode 100644 dpct/util.hpp diff --git a/dpct.hpp b/dpct.hpp new file mode 100644 index 000000000..874fa1309 --- /dev/null +++ b/dpct.hpp @@ -0,0 +1,2831 @@ +// COPY from DPCT head files +// To clear the code, copy/paste the variable/macro/function from following files. +// It' possible to get better performance from newer function version DPCT head files. +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#if defined(__linux__) +#include +#elif defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#error "Only support Windows and Linux." +#endif + +#if defined(__linux__) +#include +#include +#endif +#if defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#define DPCT_COMPATIBILITY_TEMP (900) + +#if defined(_MSC_VER) +#define __dpct_align__(n) __declspec(align(n)) +#define __dpct_inline__ __forceinline +#else +#define __dpct_align__(n) __attribute__((aligned(n))) +#define __dpct_inline__ __inline__ __attribute__((always_inline)) +#endif + +#if defined(_MSC_VER) +#define __dpct_noinline__ __declspec(noinline) +#else +#define __dpct_noinline__ __attribute__((noinline)) +#endif + +namespace dpct +{ + typedef sycl::queue *queue_ptr; + typedef sycl::event *event_ptr; + typedef char *device_ptr; + typedef uint8_t byte_t; + typedef sycl::buffer buffer_t; + + /// SYCL default exception handler + inline auto exception_handler = [](sycl::exception_list exceptions) + { + for (std::exception_ptr const &e : exceptions) + { + try + { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) + { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } + }; + + enum error_code + { + success = 0, + default_error = 999 + }; + + enum memcpy_direction + { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic + }; + + enum memory_region + { + global = 0, // device global memory + constant, // device constant memory + local, // device local memory + shared, // memory which can be accessed by host and device + }; + + enum class library_data_t : unsigned char + { + real_float = 0, + complex_float, + real_double, + complex_double, + real_half, + complex_half, + real_bfloat16, + complex_bfloat16, + real_int4, + complex_int4, + real_uint4, + complex_uint4, + real_int8, + complex_int8, + real_uint8, + complex_uint8, + real_int16, + complex_int16, + real_uint16, + complex_uint16, + real_int32, + complex_int32, + real_uint32, + complex_uint32, + real_int64, + complex_int64, + real_uint64, + complex_uint64, + real_int8_4, + real_int8_32, + real_uint8_4, + library_data_t_size + }; + + template + struct DataType + { + using T2 = T; + }; + template + struct DataType> + { + using T2 = std::complex; + }; + + static void destroy_event(event_ptr event) + { + delete event; + } + + static inline unsigned int get_tid() + { +#if defined(__linux__) + return syscall(SYS_gettid); +#elif defined(_WIN64) + return GetCurrentThreadId(); +#else +#error "Only support Windows and Linux." +#endif + } + + namespace detail + { + static void get_version(const sycl::device &dev, int &major, int &minor) + { + // Version string has the following format: + // a. OpenCL + // b. + std::string ver; + ver = dev.get_info(); + std::string::size_type i = 0; + while (i < ver.size()) + { + if (isdigit(ver[i])) + break; + i++; + } + major = std::stoi(&(ver[i])); + while (i < ver.size()) + { + if (ver[i] == '.') + break; + i++; + } + i++; + minor = std::stoi(&(ver[i])); + } + + template + class generic_error_type + { + public: + generic_error_type() = default; + generic_error_type(T value) : value{value} {} + operator T() const { return value; } + + private: + T value; + }; + + } // namespace detail + + /// Pitched 2D/3D memory data. + class pitched_data + { + public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void *data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void *get_data_ptr() { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_pitch() { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + size_t get_x() { return _x; } + void set_x(size_t x) { _x = x; }; + + size_t get_y() { return _y; } + void set_y(size_t y) { _y = y; } + + private: + void *_data; + size_t _pitch, _x, _y; + }; + + class device_info + { + public: + // get interface + const char *get_name() const { return _name; } + char *get_name() { return _name; } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + bool get_host_unified_memory() const { return _host_unified_memory; } + int get_major_version() const { return _major; } + int get_minor_version() const { return _minor; } + int get_integrated() const { return _integrated; } + int get_max_clock_frequency() const { return _frequency; } + int get_max_compute_units() const { return _max_compute_units; } + int get_max_work_group_size() const { return _max_work_group_size; } + int get_max_sub_group_size() const { return _max_sub_group_size; } + int get_max_work_items_per_compute_unit() const + { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const + { + return _max_register_size_per_work_group; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { return _global_mem_size; } + size_t get_local_mem_size() const { return _local_mem_size; } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { return _memory_bus_width; } + uint32_t get_device_id() const { return _device_id; } + std::array get_uuid() const { return _uuid; } + /// Returns global memory cache size in bytes. + unsigned int get_global_mem_cache_size() const + { + return _global_mem_cache_size; + } + + // set interface + void set_name(const char *name) + { + size_t length = strlen(name); + if (length < 256) + { + std::memcpy(_name, name, length + 1); + } + else + { + std::memcpy(_name, name, 255); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + { + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) + { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { _major = major; } + void set_minor_version(int minor) { _minor = minor; } + void set_integrated(int integrated) { _integrated = integrated; } + void set_max_clock_frequency(int frequency) { _frequency = frequency; } + void set_max_compute_units(int max_compute_units) + { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) + { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) + { + _local_mem_size = local_mem_size; + } + void set_max_work_group_size(int max_work_group_size) + { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) + { + _max_sub_group_size = max_sub_group_size; + } + void + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) + { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) + { + for (int i = 0; i < 3; i++) + { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) + { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) + { + _memory_bus_width = memory_bus_width; + } + void + set_max_register_size_per_work_group(int max_register_size_per_work_group) + { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) + { + _device_id = device_id; + } + void set_uuid(std::array uuid) + { + _uuid = std::move(uuid); + } + void set_global_mem_cache_size(unsigned int global_mem_cache_size) + { + _global_mem_cache_size = global_mem_cache_size; + } + + private: + char _name[256]; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. + unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; + }; + + static int get_major_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return major; + } + + static int get_minor_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return minor; + } + + static void get_device_info(device_info &out, const sycl::device &dev) + { + device_info prop; + prop.set_name(dev.get_info().c_str()); + + int major, minor; + detail::get_version(dev, major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); + + prop.set_max_work_item_sizes( +#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) + // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes + // is an enum class element + dev.get_info()); +#else + // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by + // an int + dev.get_info>()); +#endif + prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); + + prop.set_max_clock_frequency( + dev.get_info() * 1000); + + prop.set_max_compute_units( + dev.get_info()); + prop.set_max_work_group_size( + dev.get_info()); + prop.set_global_mem_size(dev.get_info()); + prop.set_local_mem_size(dev.get_info()); + +#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) + if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) + { + unsigned int tmp = + dev.get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) + { + prop.set_memory_bus_width( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_id)) + { + prop.set_device_id( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) + { + prop.set_uuid(dev.get_info()); + } +#elif defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value.") +#else +#warning "get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value." +#endif + + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + dev.get_info(); + + for (const auto &sub_group_size : sub_group_sizes) + { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } + + prop.set_max_sub_group_size(max_sub_group_size); + + prop.set_max_work_items_per_compute_unit( + dev.get_info()); + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); + + // Estimates max register size per work group, feel free to update the value + // according to device properties. + prop.set_max_register_size_per_work_group(65536); + + prop.set_global_mem_cache_size( + dev.get_info()); + out = prop; + } + + /// dpct device extension + class device_ext : public sycl::device + { + typedef std::mutex mutex_type; + + public: + device_ext() : sycl::device(), _ctx(*this) {} + ~device_ext() + { + std::lock_guard lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) + { + std::lock_guard lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { return 0; } + int get_major_version() const + { + return dpct::get_major_version(*this); + } + + int get_minor_version() const + { + return dpct::get_minor_version(*this); + } + + int get_max_compute_units() const + { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const + { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const + { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const + { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const + { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const + { + return get_info(); + } + + size_t get_global_mem_size() const + { + return get_device_info().get_global_mem_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the SYCL device. + /// \param [out] total_memory The number of bytes of total memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) + { +#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) + if (!has(sycl::aspect::ext_intel_free_memory)) + { + std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; + free_memory = 0; + } + else + { + free_memory = get_info(); + } +#else + std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; + free_memory = 0; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma message("Querying the number of bytes of free memory is not supported") +#else +#warning "Querying the number of bytes of free memory is not supported" +#endif +#endif + total_memory = get_device_info().get_global_mem_size(); + } + + void get_device_info(device_info &out) const + { + dpct::get_device_info(out, *this); + } + + device_info get_device_info() const + { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } + + void reset() + { + std::lock_guard lock(m_mutex); + clear_queues(); + init_queues(); + } + + sycl::queue &in_order_queue() { return *_q_in_order; } + + sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + + sycl::queue &default_queue() + { +#ifdef DPCT_USM_LEVEL_NONE + return out_of_order_queue(); +#else + return in_order_queue(); +#endif // DPCT_USM_LEVEL_NONE + } + + void queues_wait_and_throw() + { + std::unique_lock lock(m_mutex); + std::vector> current_queues( + _queues); + lock.unlock(); + for (const auto &q : current_queues) + { + q->wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is safe. + lock.lock(); + } + + sycl::queue *create_queue(bool enable_exception_handler = false) + { +#ifdef DPCT_USM_LEVEL_NONE + return create_out_of_order_queue(enable_exception_handler); +#else + return create_in_order_queue(enable_exception_handler); +#endif // DPCT_USM_LEVEL_NONE + } + + sycl::queue *create_in_order_queue(bool enable_exception_handler = false) + { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) + { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue *&queue) + { + std::lock_guard lock(m_mutex); + _queues.erase(std::remove_if(_queues.begin(), _queues.end(), + [=](const std::shared_ptr &q) -> bool + { + return q.get() == queue; + }), + _queues.end()); + queue = nullptr; + } + void set_saved_queue(sycl::queue *q) + { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + sycl::queue *get_saved_queue() const + { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + sycl::context get_context() const { return _ctx; } + + private: + void clear_queues() + { + _queues.clear(); + _q_in_order = _q_out_of_order = _saved_queue = nullptr; + } + + void init_queues() + { + _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = &default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this function. + template + sycl::queue *create_queue_impl(bool enable_exception_handler, + Properties... properties) + { + sycl::async_handler eh = {}; + if (enable_exception_handler) + { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + _ctx, *this, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back().get(); + } + + void get_version(int &major, int &minor) const + { + detail::get_version(*this, major, minor); + } + sycl::queue *_q_in_order, *_q_out_of_order; + sycl::queue *_saved_queue; + sycl::context _ctx; + std::vector> _queues; + mutable mutex_type m_mutex; + }; + + /// device manager + class dev_mgr + { + public: + device_ext ¤t_device() + { + unsigned int dev_id = current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext &cpu_device() const + { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) + { + throw std::runtime_error("no valid cpu device"); + } + else + { + return *_devs[_cpu_device]; + } + } + device_ext &get_device(unsigned int id) const + { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const + { + std::lock_guard lock(m_mutex); + auto it = _thread2dev_map.find(get_tid()); + if (it != _thread2dev_map.end()) + return it->second; + return DEFAULT_DEVICE_ID; + } + + /// Select device with a device ID. + /// \param [in] id The id of the device which can + /// be obtained through get_device_id(const sycl::device). + void select_device(unsigned int id) + { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()] = id; + } + unsigned int device_count() { return _devs.size(); } + + unsigned int get_device_id(const sycl::device &dev) + { + unsigned int id = 0; + for (auto dev_item : _devs) + { + if (*dev_item == dev) + { + break; + } + id++; + } + return id; + } + + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector &selector = sycl::gpu_selector_v) + { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr &instance() + { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr &) = delete; + dev_mgr &operator=(const dev_mgr &) = delete; + dev_mgr(dev_mgr &&) = delete; + dev_mgr &operator=(dev_mgr &&) = delete; + + private: + mutable std::recursive_mutex m_mutex; + dev_mgr() + { + sycl::device default_device = + sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs = + sycl::device::get_devices(sycl::info::device_type::all); + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + for (auto &dev : sycl_all_devs) + { + if (dev == default_device) + { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) + { + _cpu_device = _devs.size() - 1; + } + } + } + void check_id(unsigned int id) const + { + if (id >= _devs.size()) + { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int DEFAULT_DEVICE_ID = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; + }; + + static inline sycl::queue &get_default_queue() + { + return dev_mgr::instance().current_device().default_queue(); + } + + namespace detail + { + enum class pointer_access_attribute + { + host_only = 0, + device_only, + host_device, + end + }; + + static pointer_access_attribute get_pointer_attribute(sycl::queue &q, + const void *ptr) + { +#ifdef DPCT_USM_LEVEL_NONE + return mem_mgr::instance().is_device_ptr(ptr) + ? pointer_access_attribute::device_only + : pointer_access_attribute::host_only; +#else + switch (sycl::get_pointer_type(ptr, q.get_context())) + { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } +#endif + } + + template + inline constexpr std::uint64_t get_type_combination_id(ArgT Val) + { + static_assert((unsigned char)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(std::is_same_v, "Unsupported ArgT"); + return (std::uint64_t)Val; + } + + template + inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, + RestT... RestVal) + { + static_assert((std::uint8_t)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); + static_assert(std::is_same_v, "Unsupported FirstT"); + return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); + } + + class mem_mgr + { + mem_mgr() + { + // Reserved address space, no real memory allocation happens here. +#if defined(__linux__) + mapped_address_space = + (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); +#elif defined(_WIN64) + mapped_address_space = (byte_t *)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access +#else +#error "Only support Windows and Linux." +#endif + next_free = mapped_address_space; + }; + + public: + using buffer_id_t = int; + + struct allocation + { + buffer_t buffer; + byte_t *alloc_ptr; + size_t size; + }; + + ~mem_mgr() + { +#if defined(__linux__) + munmap(mapped_address_space, mapped_region_size); +#elif defined(_WIN64) + VirtualFree(mapped_address_space, 0, MEM_RELEASE); +#else +#error "Only support Windows and Linux." +#endif + }; + + mem_mgr(const mem_mgr &) = delete; + mem_mgr &operator=(const mem_mgr &) = delete; + mem_mgr(mem_mgr &&) = delete; + mem_mgr &operator=(mem_mgr &&) = delete; + + /// Allocate + void *mem_alloc(size_t size) + { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) + { + throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> r(size); + buffer_t buf(r); + allocation A{buf, next_free, size}; + // Map allocation to device pointer + void *result = next_free; + m_map.emplace(next_free + size, A); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void *ptr) + { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void *ptr) + { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void *ptr) const + { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr &instance() + { + static mem_mgr m; + return m; + } + + private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t *mapped_address_space; + byte_t *next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. + const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void *ptr) + { + auto it = m_map.upper_bound((byte_t *)ptr); + if (it == m_map.end()) + { + // Not a virtual pointer. + throw std::runtime_error("can not get buffer from non-virtual pointer"); + } + const allocation &alloc = it->second; + if (ptr < alloc.alloc_ptr) + { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("invalid virtual pointer"); + } + return it; + } + }; + + template + class accessor; + template + class memory_traits + { + public: + static constexpr sycl::access::target target = + sycl::access::target::device; + static constexpr sycl::access_mode mode = + (Memory == constant) ? sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional::type; + using value_t = typename std::remove_cv::type; + template + using accessor_t = typename std::conditional< + Memory == local, sycl::local_accessor, + sycl::accessor>::type; + using pointer_t = T *; + }; + + static inline void *dpct_malloc(size_t size, sycl::queue &q) + { +#ifdef DPCT_USM_LEVEL_NONE + return mem_mgr::instance().mem_alloc(size * sizeof(byte_t)); +#else + return sycl::malloc_device(size, q.get_device(), q.get_context()); +#endif // DPCT_USM_LEVEL_NONE + } + +#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) + static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, + sycl::queue &q) + { + pitch = PITCH_DEFAULT_ALIGN(x); + return dpct_malloc(pitch * y * z, q); + } + + /** + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ + template + static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, + valueT value, size_t size) + { +#ifdef DPCT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + assert(mm.is_device_ptr(dev_ptr)); + auto alloc = mm.translate_ptr(dev_ptr); + size_t offset = (valueT *)dev_ptr - (valueT *)alloc.alloc_ptr; + + return q.submit([&](sycl::handler &cgh) + { + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + auto new_buffer = alloc.buffer.reinterpret( + sycl::range<1>(alloc.size / sizeof(valueT))); + sycl::accessor + acc(new_buffer, cgh, r, o); + cgh.fill(acc, value); }); +#else + return q.fill(dev_ptr, value, size); +#endif // DPCT_USM_LEVEL_NONE + } + + /** + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, pitched_data data, valueT value, + sycl::range<3> size) + { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; + } + + /** + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, + size_t y) + { + return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, + sycl::range<3>(x, y, 1)); + } + + static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, + const void *from_ptr, + memcpy_direction dir) + { + switch (dir) + { + case memcpy_direction::host_to_host: + case memcpy_direction::host_to_device: + case memcpy_direction::device_to_host: + case memcpy_direction::device_to_device: + return dir; + case memcpy_direction::automatic: + { + // table[to_attribute][from_attribute] + static const memcpy_direction + direction_table[static_cast(pointer_access_attribute::end)] + [static_cast(pointer_access_attribute::end)] = + {{memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}}; + return direction_table[static_cast(get_pointer_attribute( + q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; +#ifdef DPCT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + + switch (real_direction) + { + case host_to_host: + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); }); + case host_to_device: + { + auto alloc = mm.translate_ptr(to_ptr); + size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(from_ptr, acc); }); + } + case device_to_host: + { + auto alloc = mm.translate_ptr(from_ptr); + size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(acc, to_ptr); }); + } + case device_to_device: + { + auto to_alloc = mm.translate_ptr(to_ptr); + auto from_alloc = mm.translate_ptr(from_ptr); + size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_ptr - from_alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, r, to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, r, from_o); + cgh.copy(from_acc, to_acc); }); + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } +#else + return q.memcpy(to_ptr, from_ptr, size, dep_events); +#endif // DPCT_USM_LEVEL_NONE + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: +#ifdef DPCT_USM_LEVEL_NONE + { + auto &mm = mem_mgr::instance(); + auto to_alloc = mm.translate_ptr(to_surface); + auto from_alloc = mm.translate_ptr(from_surface); + size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr; + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, + get_copy_range(size, to_slice, to_range.get(0)), to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, + get_copy_range(size, from_slice, from_range.get(0)), from_o); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_acc[get_offset(id, to_slice, to_range.get(0))] = + from_acc[get_offset(id, from_slice, from_range.get(0))]; + }); })); + } +#else + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); +#endif + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + namespace deprecated + { + + template + class usm_allocator + { + private: + using Alloc = sycl::usm_allocator; + Alloc _impl; + + public: + using value_type = typename std::allocator_traits::value_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using void_pointer = typename std::allocator_traits::void_pointer; + using const_void_pointer = + typename std::allocator_traits::const_void_pointer; + using reference = typename std::allocator_traits::value_type &; + using const_reference = + const typename std::allocator_traits::value_type &; + using difference_type = + typename std::allocator_traits::difference_type; + using size_type = typename std::allocator_traits::size_type; + using propagate_on_container_copy_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_copy_assignment; + using propagate_on_container_move_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_move_assignment; + using propagate_on_container_swap = + typename std::allocator_traits::propagate_on_container_swap; + using is_always_equal = + typename std::allocator_traits::is_always_equal; + + template + struct rebind + { + typedef usm_allocator other; + }; + + usm_allocator() : _impl(dpct::get_default_queue()) {} + ~usm_allocator() {} + usm_allocator(const usm_allocator &other) : _impl(other._impl) {} + usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} + pointer address(reference r) { return &r; } + const_pointer address(const_reference r) { return &r; } + pointer allocate(size_type cnt, const_void_pointer hint = nullptr) + { + return std::allocator_traits::allocate(_impl, cnt, hint); + } + void deallocate(pointer p, size_type cnt) + { + std::allocator_traits::deallocate(_impl, p, cnt); + } + size_type max_size() const + { + return std::allocator_traits::max_size(_impl); + } + bool operator==(const usm_allocator &other) const { return _impl == other._impl; } + bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } + }; + + } // namespace deprecated + + inline void dpct_free(void *ptr, + const sycl::queue &q) + { + if (ptr) + { +#ifdef DPCT_USM_LEVEL_NONE + detail::mem_mgr::instance().mem_free(ptr); +#else + sycl::free(ptr, q.get_context()); +#endif // DPCT_USM_LEVEL_NONE + } + } + + template + inline auto get_memory(const void *x) + { + T *new_x = reinterpret_cast(const_cast(x)); +#ifdef DPCT_USM_LEVEL_NONE + return dpct::get_buffer>(new_x); +#else + return new_x; +#endif + } + + template + inline typename DataType::T2 get_value(const T *s, sycl::queue &q) + { + using Ty = typename DataType::T2; + Ty s_h; + if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) + detail::dpct_memcpy(q, (void *)&s_h, (void *)s, sizeof(T), device_to_host) + .wait(); + else + s_h = *reinterpret_cast(s); + return s_h; + } + + } // namespace detail + + template + inline auto get_value(const T *s, sycl::queue &q) + { + return detail::get_value(s, q); + } + + namespace detail + { + template + inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, int lda, const void *b, + int ldb, const void *beta, void *c, int ldc) + { +#ifndef __INTEL_MKL__ + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + data_b, ldb, beta_value, data_c, ldc); +#endif + } + + template + class vectorized_binary + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) + { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } + }; + + template + class vectorized_binary< + VecT, BinaryOperation, + std::void_t>> + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + return binary_op(a, b).template as(); + } + }; + + template + inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void **a, int lda, + const void **b, int ldb, const void *beta, void **c, + int ldc, int batch_size) + { + struct matrix_info_t + { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; + }; + + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + + matrix_info_t *matrix_info = + (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); + matrix_info->transpose_info[0] = a_trans; + matrix_info->transpose_info[1] = b_trans; + matrix_info->value_info[0] = alpha_value; + matrix_info->value_info[1] = beta_value; + matrix_info->size_info[0] = m; + matrix_info->size_info[1] = n; + matrix_info->size_info[2] = k; + matrix_info->ld_info[0] = lda; + matrix_info->ld_info[1] = ldb; + matrix_info->ld_info[2] = ldc; + matrix_info->groupsize_info = batch_size; + + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, + matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, matrix_info->value_info, + reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), + matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + + q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(e); + cgh.host_task([=] { std::free(matrix_info); }); }); + } + + template + inline void + gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, + int k, const void *alpha, const void *a, int lda, + long long int stride_a, const void *b, int ldb, + long long int stride_b, const void *beta, void *c, + int ldc, long long int stride_c, int batch_size) + { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm_batch( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, + data_c, ldc, stride_c, batch_size); + } + + } // namespace detail + + template + inline unsigned vectorized_binary(unsigned a, unsigned b, + const BinaryOperation binary_op) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + v0 = v4.template as>(); + return v0; + } + + static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); + } + + static inline unsigned int select_device(unsigned int id) + { + dev_mgr::instance().select_device(id); + return id; + } + + template + T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, + int logical_sub_group_size = 32) + { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + return sycl::select_from_group(g, x, + target_offset < logical_sub_group_size + ? start_index + target_offset + : id); + } + + template + sycl::vec extract_and_sign_or_zero_extend4(T val) + { + return sycl::vec(val) + .template as, int8_t, uint8_t>, 4>>() + .template convert(); + } + + template + using dot_product_acc_t = + std::conditional_t && std::is_unsigned_v, + uint32_t, int32_t>; + + template + inline auto dp4a(T1 a, T2 b, T3 c) + { + dot_product_acc_t res = c; + auto va = extract_and_sign_or_zero_extend4(a); + auto vb = extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[0]; + res += va[1] * vb[1]; + res += va[2] * vb[2]; + res += va[3] * vb[3]; + return res; + } + + struct sub_sat + { + template + auto operator()(const T x, const T y) const + { + return sycl::sub_sat(x, y); + } + }; + + template + inline T vectorized_min(T a, T b) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::min(v2, v3); + v0 = v4.template as>(); + return v0; + } + + inline float pow(const float a, const int b) { return sycl::pown(a, b); } + inline double pow(const double a, const int b) { return sycl::pown(a, b); } + inline float pow(const float a, const float b) { return sycl::pow(a, b); } + inline double pow(const double a, const double b) { return sycl::pow(a, b); } + template + inline typename std::enable_if_t, T> + pow(const T a, const U b) + { + return sycl::pow(a, static_cast(b)); + } + template + inline typename std::enable_if_t, double> + pow(const T a, const U b) + { + return sycl::pow(static_cast(a), static_cast(b)); + } + + inline double min(const double a, const float b) + { + return sycl::fmin(a, static_cast(b)); + } + inline double min(const float a, const double b) + { + return sycl::fmin(static_cast(a), b); + } + inline float min(const float a, const float b) { return sycl::fmin(a, b); } + inline double min(const double a, const double b) { return sycl::fmin(a, b); } + inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int32_t min(const std::int32_t a, const std::int32_t b) + { + return sycl::min(a, b); + } + inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int64_t min(const std::int64_t a, const std::int64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + // max function overloads. + // For floating-point types, `float` or `double` arguments are acceptable. + // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or + // `std::int64_t` type arguments are acceptable. + inline double max(const double a, const float b) + { + return sycl::fmax(a, static_cast(b)); + } + inline double max(const float a, const double b) + { + return sycl::fmax(static_cast(a), b); + } + inline float max(const float a, const float b) { return sycl::fmax(a, b); } + inline double max(const double a, const double b) { return sycl::fmax(a, b); } + inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int32_t max(const std::int32_t a, const std::int32_t b) + { + return sycl::max(a, b); + } + inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int64_t max(const std::int64_t a, const std::int64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + + inline void + has_capability_or_fail(const sycl::device &dev, + const std::initializer_list &props) + { + for (const auto &it : props) + { + if (dev.has(it)) + continue; + switch (it) + { + case sycl::aspect::fp64: + throw std::runtime_error("'double' is not supported in '" + + dev.get_info() + + "' device"); + break; + case sycl::aspect::fp16: + throw std::runtime_error("'half' is not supported in '" + + dev.get_info() + + "' device"); + break; + default: +#define __SYCL_ASPECT(ASPECT, ID) \ + case sycl::aspect::ASPECT: \ + return #ASPECT; +#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) +#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string + { + switch (AspectNum) + { +#include +#include + default: + return "unknown aspect"; + } + }; +#undef __SYCL_ASPECT_DEPRECATED_ALIAS +#undef __SYCL_ASPECT_DEPRECATED +#undef __SYCL_ASPECT + throw std::runtime_error( + "'" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); + } + break; + } + } + + static inline unsigned int get_current_device_id() + { + return dev_mgr::instance().current_device_id(); + } + + static inline device_ext &get_current_device() + { + return dev_mgr::instance().current_device(); + } + + static inline sycl::queue &get_in_order_queue() + { + return dev_mgr::instance().current_device().in_order_queue(); + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; +#ifdef DPCT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + + switch (real_direction) + { + case host_to_host: + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); }); + case host_to_device: + { + auto alloc = mm.translate_ptr(to_ptr); + size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(from_ptr, acc); }); + } + case device_to_host: + { + auto alloc = mm.translate_ptr(from_ptr); + size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(acc, to_ptr); }); + } + case device_to_device: + { + auto to_alloc = mm.translate_ptr(to_ptr); + auto from_alloc = mm.translate_ptr(from_ptr); + size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_ptr - from_alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, r, to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, r, from_o); + cgh.copy(from_acc, to_acc); }); + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } +#else + return q.memcpy(to_ptr, from_ptr, size, dep_events); +#endif // DPCT_USM_LEVEL_NONE + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: +#ifdef DPCT_USM_LEVEL_NONE + { + auto &mm = mem_mgr::instance(); + auto to_alloc = mm.translate_ptr(to_surface); + auto from_alloc = mm.translate_ptr(from_surface); + size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr; + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, + get_copy_range(size, to_slice, to_range.get(0)), to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, + get_copy_range(size, from_slice, from_range.get(0)), from_o); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_acc[get_offset(id, to_slice, to_range.get(0))] = + from_acc[get_offset(id, from_slice, from_range.get(0))]; + }); })); + } +#else + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); +#endif + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, const void *b, library_data_t b_type, int ldb, + const void *beta, void *c, library_data_t c_type, int ldc, + library_data_t scaling_type) + { + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } // gemm() + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a[], + library_data_t a_type, int lda, const void *b[], + library_data_t b_type, int ldb, const void *beta, + void *c[], library_data_t c_type, int ldc, + int batch_size, library_data_t scaling_type) + { +#ifdef DPCT_USM_LEVEL_NONE + throw std::runtime_error("this API is unsupported when USM level is none"); +#else + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, + a, lda, b, ldb, &beta_float, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, + batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +#endif + } + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] stride_a Stride between the different A matrices. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] stride_b Stride between the different B matrices. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] stride_c Stride between the different C matrices. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, long long int stride_a, const void *b, + library_data_t b_type, int ldb, long long int stride_b, + const void *beta, void *c, library_data_t c_type, + int ldc, long long int stride_c, int batch_size, + library_data_t scaling_type) + { + bool matched = false; + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } + + static inline void + async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, + size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, + direction); + } + + using err0 = detail::generic_error_type; + using err1 = detail::generic_error_type; + +} // COPY from DPCT head files \ No newline at end of file diff --git a/dpct/atomic.hpp b/dpct/atomic.hpp deleted file mode 100644 index 4b516f530..000000000 --- a/dpct/atomic.hpp +++ /dev/null @@ -1,842 +0,0 @@ -//==---- atomic.hpp -------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_ATOMIC_HPP__ -#define __DPCT_ATOMIC_HPP__ - -#include - -namespace dpct { - -/// Atomically add the value operand to the value at the addr and assign the -/// result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to add to the value at \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_add(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_add(operand); -} - -template -inline T1 atomic_fetch_add(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_add(operand); -} - -/// Atomically add the value operand to the value at the addr and assign the -/// result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to add to the value at \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_add(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_add(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_add(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_add(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_add(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_add(addr, operand, memoryOrder); -} - -/// Atomically subtract the value operand from the value at the addr and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to subtract from the value at \p addr -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_sub(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_sub(operand); -} - -template -inline T1 atomic_fetch_sub(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_sub(operand); -} - -/// Atomically subtract the value operand from the value at the addr and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to subtract from the value at \p addr -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_sub(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_sub(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_sub(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_sub(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_sub(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_sub(addr, operand, memoryOrder); -} - -/// Atomically perform a bitwise AND between the value operand and the value at the addr -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise AND operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_and(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_and(operand); -} - -template -inline T1 atomic_fetch_and(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_and(operand); -} - -/// Atomically perform a bitwise AND between the value operand and the value at the addr -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise AND operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_and(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_and(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_and(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_and(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_and(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_and(addr, operand, memoryOrder); -} - -/// Atomically or the value at the addr with the value operand, and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise OR operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_or(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_or(operand); -} - -template -inline T1 atomic_fetch_or(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_or(operand); -} - -/// Atomically or the value at the addr with the value operand, and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise OR operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_or(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_or(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_or(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_or(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_or(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_or(addr, operand, memoryOrder); -} - -/// Atomically xor the value at the addr with the value operand, and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise XOR operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_xor(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_xor(operand); -} - -template -inline T1 atomic_fetch_xor(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_xor(operand); -} - -/// Atomically xor the value at the addr with the value operand, and assign -/// the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to use in bitwise XOR operation with the value at the \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_xor(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_xor(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_xor(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_xor(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_xor(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_xor(addr, operand, memoryOrder); -} - -/// Atomically calculate the minimum of the value at addr and the value operand -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_min(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_min(operand); -} - -template -inline T1 atomic_fetch_min(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_min(operand); -} - -/// Atomically calculate the minimum of the value at addr and the value operand -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_min(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_min(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_min(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_min(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_min(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_min(addr, operand, memoryOrder); -} - -/// Atomically calculate the maximum of the value at addr and the value operand -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_max(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_max(operand); -} - -template -inline T1 atomic_fetch_max(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.fetch_max(operand); -} - -/// Atomically calculate the maximum of the value at addr and the value operand -/// and assign the result to the value at addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_fetch_max(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_max(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_max(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_max(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_fetch_max(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_fetch_max(addr, operand, memoryOrder); -} - -/// Atomically set \p operand to the value stored in \p addr, if old value stored in -/// \p addr is equal to zero or greater than \p operand, else decrease the value stored -/// in \p addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The threshold value. -/// \param memoryOrder The memory ordering used. -/// \returns The old value stored in \p addr. -template -inline unsigned int atomic_fetch_compare_dec(unsigned int *addr, - unsigned int operand) { - auto atm = sycl::atomic_ref(addr[0]); - unsigned int old; - - while (true) { - old = atm.load(); - if (old == 0 || old > operand) { - if (atm.compare_exchange_strong(old, operand)) - break; - } else if (atm.compare_exchange_strong(old, old - 1)) - break; - } - - return old; -} - -/// Atomically increment the value stored in \p addr if old value stored in \p -/// addr is less than \p operand, else set 0 to the value stored in \p addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The threshold value. -/// \param memoryOrder The memory ordering used. -/// \returns The old value stored in \p addr. -template -inline unsigned int atomic_fetch_compare_inc(unsigned int *addr, - unsigned int operand) { - auto atm = sycl::atomic_ref(addr[0]); - unsigned int old; - while (true) { - old = atm.load(); - if (old >= operand) { - if (atm.compare_exchange_strong(old, 0)) - break; - } else if (atm.compare_exchange_strong(old, old + 1)) - break; - } - return old; -} - -/// Atomically increment the value stored in \p addr if old value stored in \p -/// addr is less than \p operand, else set 0 to the value stored in \p addr. -/// \param [in, out] addr The pointer to the data. -/// \param operand The threshold value. -/// \param memoryOrder The memory ordering used. -/// \returns The old value stored in \p addr. -template -inline unsigned int -atomic_fetch_compare_inc(unsigned int *addr, unsigned int operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_fetch_compare_inc(addr, - operand); - case sycl::memory_order::acq_rel: - return atomic_fetch_compare_inc(addr, - operand); - case sycl::memory_order::seq_cst: - return atomic_fetch_compare_inc(addr, - operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -/// Atomically exchange the value at the address addr with the value operand. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to be exchanged with the value pointed by \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_exchange(T *addr, T operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.exchange(operand); -} - -template -inline T1 atomic_exchange(T1 *addr, T2 operand) { - auto atm = - sycl::atomic_ref(addr[0]); - return atm.exchange(operand); -} - -/// Atomically exchange the value at the address addr with the value operand. -/// \param [in, out] addr The pointer to the data. -/// \param operand The value to be exchanged with the value pointed by \p addr. -/// \param memoryOrder The memory ordering used. -/// \returns The value at the \p addr before the call. -template -inline T atomic_exchange(T *addr, T operand, - sycl::memory_order memoryOrder) { - switch (memoryOrder) { - case sycl::memory_order::relaxed: - return atomic_exchange(addr, operand); - case sycl::memory_order::acq_rel: - return atomic_exchange(addr, operand); - case sycl::memory_order::seq_cst: - return atomic_exchange(addr, operand); - default: - assert(false && "Invalid memory_order for atomics. Valid memory_order for " - "atomics are: sycl::memory_order::relaxed, " - "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!"); - } -} - -template -inline T1 atomic_exchange(T1 *addr, T2 operand, - sycl::memory_order memoryOrder) { - atomic_exchange(addr, operand, memoryOrder); -} - -/// Atomically compare the value at \p addr to the value expected and exchange -/// with the value desired if the value at \p addr is equal to the value expected. -/// Returns the value at the \p addr before the call. -/// \param [in, out] addr Multi_ptr. -/// \param expected The value to compare against the value at \p addr. -/// \param desired The value to assign to \p addr if the value at \p addr is expected. -/// \param success The memory ordering used when comparison succeeds. -/// \param fail The memory ordering used when comparison fails. -/// \returns The value at the \p addr before the call. -template -T atomic_compare_exchange_strong( - sycl::multi_ptr addr, T expected, T desired, - sycl::memory_order success = sycl::memory_order::relaxed, - sycl::memory_order fail = sycl::memory_order::relaxed) { - auto atm = sycl::atomic_ref(*addr); - - atm.compare_exchange_strong(expected, desired, success, fail); - return expected; -} - -template -T1 atomic_compare_exchange_strong( - sycl::multi_ptr addr, T2 expected, T3 desired, - sycl::memory_order success = sycl::memory_order::relaxed, - sycl::memory_order fail = sycl::memory_order::relaxed) { - auto atm = - sycl::atomic_ref(*addr); - T1 expected_value = expected; - atm.compare_exchange_strong(expected_value, desired, success, fail); - return expected_value; -} - -/// Atomically compare the value at \p addr to the value expected and exchange -/// with the value desired if the value at \p addr is equal to the value expected. -/// Returns the value at the \p addr before the call. -/// \param [in] addr The pointer to the data. -/// \param expected The value to compare against the value at \p addr. -/// \param desired The value to assign to \p addr if the value at \p addr is expected. -/// \param success The memory ordering used when comparison succeeds. -/// \param fail The memory ordering used when comparison fails. -/// \returns The value at the \p addr before the call. -template -T atomic_compare_exchange_strong( - T *addr, T expected, T desired, - sycl::memory_order success = sycl::memory_order::relaxed, - sycl::memory_order fail = sycl::memory_order::relaxed) { - auto atm = - sycl::atomic_ref(addr[0]); - atm.compare_exchange_strong(expected, desired, success, fail); - return expected; -} - -template -T1 atomic_compare_exchange_strong( - T1 *addr, T2 expected, T3 desired, - sycl::memory_order success = sycl::memory_order::relaxed, - sycl::memory_order fail = sycl::memory_order::relaxed) { - T1 expected_value = expected; - auto atm = - sycl::atomic_ref(addr[0]); - atm.compare_exchange_strong(expected_value, desired, success, fail); - return expected_value; -} - -/// Atomic extension to implement standard APIs in std::atomic -namespace detail{ -template struct IsValidAtomicType { - static constexpr bool value = - (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_pointer::value); -}; -} // namespace detail - -template -class atomic{ - static_assert( - detail::IsValidAtomicType::value, - "Invalid atomic type. Valid types are int, unsigned int, long, " - "unsigned long, long long, unsigned long long, float, double " - "and pointer types"); - T __d; - -public: - /// default memory synchronization order - static constexpr sycl::memory_order default_read_order = - sycl::atomic_ref::default_read_order; - static constexpr sycl::memory_order default_write_order = - sycl::atomic_ref::default_write_order; - static constexpr sycl::memory_scope default_scope = DefaultScope; - static constexpr sycl::memory_order default_read_modify_write_order = - DefaultOrder; - - - /// Default constructor. - constexpr atomic() noexcept = default; - /// Constructor with initialize value. - constexpr atomic(T d) noexcept : __d(d){}; - - /// atomically replaces the value of the referenced object with a non-atomic argument - /// \param operand The value to replace the pointed value. - /// \param memoryOrder The memory ordering used. - /// \param memoryScope The memory scope used. - void store(T operand, sycl::memory_order memoryOrder = default_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - sycl::atomic_ref atm(__d); - atm.store(operand, memoryOrder, memoryScope); - } - - /// atomically obtains the value of the referenced object - /// \param memoryOrder The memory ordering used. - /// \param memoryScope The memory scope used. - /// \returns The value of the referenced object - T load(sycl::memory_order memoryOrder = default_read_order, - sycl::memory_scope memoryScope = default_scope) const noexcept { - sycl::atomic_ref atm( - const_cast(__d)); - return atm.load(memoryOrder, memoryScope); - } - - /// atomically replaces the value of the referenced object and obtains the value held previously - /// \param operand The value to replace the pointed value. - /// \param memoryOrder The memory ordering used. - /// \param memoryScope The memory scope used. - /// \returns The value of the referenced object before the call. - T exchange(T operand, - sycl::memory_order memoryOrder = default_read_modify_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - - sycl::atomic_ref atm(__d); - return atm.exchange(operand, memoryOrder, memoryScope); - } - - /// atomically compares the value of the referenced object with non-atomic argument - /// and performs atomic exchange if equal or atomic load if not - /// \param expected The value expected to be found in the object referenced by the atomic_ref object - /// \param desired The value to store in the referenced object if it is as expected - /// \param success The memory models for the read-modify-write - /// \param failure The memory models for load operations - /// \param memoryScope The memory scope used. - /// \returns true if the referenced object was successfully changed, false otherwise. - bool compare_exchange_weak( - T &expected, T desired, - sycl::memory_order success, sycl::memory_order failure, - sycl::memory_scope memoryScope = default_scope) noexcept { - sycl::atomic_ref atm(__d); - return atm.compare_exchange_weak(expected, desired, success, failure, memoryScope); - } - /// \param expected The value expected to be found in the object referenced by the atomic_ref object - /// \param desired The value to store in the referenced object if it is as expected - /// \param memoryOrder The memory synchronization ordering for operations - /// \param memoryScope The memory scope used. - /// \returns true if the referenced object was successfully changed, false otherwise. - bool compare_exchange_weak(T &expected, T desired, - sycl::memory_order memoryOrder = default_read_modify_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - sycl::atomic_ref atm(__d); - return atm.compare_exchange_weak(expected, desired, memoryOrder, memoryScope); - } - - /// atomically compares the value of the referenced object with non-atomic argument - /// and performs atomic exchange if equal or atomic load if not - /// \param expected The value expected to be found in the object referenced by the atomic_ref object - /// \param desired The value to store in the referenced object if it is as expected - /// \param success The memory models for the read-modify-write - /// \param failure The memory models for load operations - /// \param memoryScope The memory scope used. - /// \returns true if the referenced object was successfully changed, false otherwise. - bool compare_exchange_strong( - T &expected, T desired, - sycl::memory_order success, sycl::memory_order failure, - sycl::memory_scope memoryScope = default_scope) noexcept { - - sycl::atomic_ref atm(__d); - return atm.compare_exchange_strong(expected, desired, success, failure, memoryScope); - } - /// \param expected The value expected to be found in the object referenced by the atomic_ref object - /// \param desired The value to store in the referenced object if it is as expected - /// \param memoryOrder The memory synchronization ordering for operations - /// \param memoryScope The memory scope used. - /// \returns true if the referenced object was successfully changed, false otherwise. - bool compare_exchange_strong(T &expected, T desired, - sycl::memory_order memoryOrder = default_read_modify_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - sycl::atomic_ref atm(__d); - return atm.compare_exchange_strong(expected, desired, memoryOrder, memoryScope); - } - - /// atomically adds the argument to the value stored in the atomic object and obtains the value held previously - /// \param operand The other argument of arithmetic addition - /// \param memoryOrder The memory ordering used. - /// \param memoryScope The memory scope used. - /// \returns The value of the referenced object before the call. - T fetch_add(T operand, - sycl::memory_order memoryOrder = default_read_modify_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - - sycl::atomic_ref atm(__d); - return atm.fetch_add(operand, memoryOrder, memoryScope); - } - - /// atomically subtracts the argument from the value stored in the atomic object and obtains the value held previously - /// \param operand The other argument of arithmetic subtraction - /// \param memoryOrder The memory ordering used. - /// \param memoryScope The memory scope used. - /// \returns The value of the referenced object before the call. - T fetch_sub(T operand, - sycl::memory_order memoryOrder = default_read_modify_write_order, - sycl::memory_scope memoryScope = default_scope) noexcept { - - sycl::atomic_ref atm(__d); - return atm.fetch_sub(operand, memoryOrder, memoryScope); - } -}; - -} // namespace dpct -#endif // __DPCT_ATOMIC_HPP__ diff --git a/dpct/blas_utils.hpp b/dpct/blas_utils.hpp deleted file mode 100644 index df222c528..000000000 --- a/dpct/blas_utils.hpp +++ /dev/null @@ -1,1792 +0,0 @@ -//==---- blas_utils.hpp----------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_BLAS_UTILS_HPP__ -#define __DPCT_BLAS_UTILS_HPP__ - -#include "memory.hpp" -#include "util.hpp" -#include "lib_common_utils.hpp" -#include -#include -#include -#include -#include - -namespace dpct { - -/// Get the value of \p s. -/// Copy the data to host synchronously, then return the data. -/// \param [in] p The pointer points the data. -/// \param [in] q The queue where the memory copy should be executed. -template -inline auto get_value(const T *s, sycl::queue &q) { - return detail::get_value(s, q); -} - -namespace detail { -inline void mem_free(sycl::queue *exec_queue, - std::vector pointers_array, sycl::event e) { - e.wait(); - for (auto p : pointers_array) - sycl::free(p, *exec_queue); -} - -inline int stride_for(int num_elems, int mem_align_in_elems) { - return ((num_elems - 1) / mem_align_in_elems + 1) * mem_align_in_elems; -} - -#ifndef DPCT_USM_LEVEL_NONE -template -class working_memory { - T *_input_ptr; - T *_temp_ptr; - bool _is_sycl_malloced = false; - bool _is_scalar_value = false; - sycl::queue _q; - sycl::event _e; - -public: - working_memory(size_t size, sycl::queue q) : _q(q) { - _is_scalar_value = false; - _temp_ptr = (T *)sycl::malloc_device(size, q); - } - working_memory(T *result_ptr, sycl::queue q) : _input_ptr(result_ptr), _q(q) { - _is_scalar_value = true; - _is_sycl_malloced = sycl::get_pointer_type(_input_ptr, _q.get_context()) != - sycl::usm::alloc::unknown; - if (!_is_sycl_malloced) - _temp_ptr = sycl::malloc_shared(1, _q); - } - auto get_ptr() { - if (_is_scalar_value && _is_sycl_malloced) - return _input_ptr; - return _temp_ptr; - } - void set_event(sycl::event e) { _e = e; } - ~working_memory() { - if (_is_scalar_value) { - if (!_is_sycl_malloced) { - _q.memcpy(_input_ptr, _temp_ptr, sizeof(T)).wait(); - sycl::free(_temp_ptr, _q); - } - } else { - std::vector ptrs{_temp_ptr}; - dpct::async_dpct_free(ptrs, {_e}); - } - } -}; -#endif - -template -inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx, - void *result) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else -#ifdef DPCT_USM_LEVEL_NONE - auto x_buffer = dpct::get_buffer(x); - auto r_buffer = - sycl::buffer(reinterpret_cast(result), sycl::range<1>(1)); - if (dpct::is_device_ptr(result)) - r_buffer = dpct::get_buffer(result); - oneapi::mkl::blas::column_major::nrm2(q, n, x_buffer, incx, r_buffer); -#else - working_memory res_mem(reinterpret_cast(result), q); - oneapi::mkl::blas::column_major::nrm2(q, n, reinterpret_cast(x), - incx, res_mem.get_ptr()); -#endif -#endif -} - -template -inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx, - const Txy *y, int incy, Tr *result) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else -#ifdef DPCT_USM_LEVEL_NONE - auto x_buffer = dpct::get_buffer(x); - auto y_buffer = dpct::get_buffer(y); - auto r_buffer = sycl::buffer((Tr *)result, sycl::range<1>(1)); - if (dpct::is_device_ptr(result)) - r_buffer = dpct::get_buffer(result); - if constexpr (std::is_same_v> || - std::is_same_v>) { - if constexpr (is_conjugate) - oneapi::mkl::blas::column_major::dotc(q, n, x_buffer, incx, y_buffer, - incy, r_buffer); - else - oneapi::mkl::blas::column_major::dotu(q, n, x_buffer, incx, y_buffer, - incy, r_buffer); - } else - oneapi::mkl::blas::column_major::dot(q, n, x_buffer, incx, y_buffer, incy, - r_buffer); -#else - working_memory res_mem(result, q); - if constexpr (std::is_same_v> || - std::is_same_v>) { - if constexpr (is_conjugate) - oneapi::mkl::blas::column_major::dotc(q, n, x, incx, y, incy, res_mem.get_ptr()); - else - oneapi::mkl::blas::column_major::dotu(q, n, x, incx, y, incy, res_mem.get_ptr()); - } else - oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr()); -#endif -#endif -} - -template -inline void dotuc(sycl::queue &q, int n, const void *x, - library_data_t x_type, int incx, const void *y, - library_data_t y_type, int incy, void *result, - library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, y_type, result_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float, - library_data_t::complex_float): { - detail::dotuc_impl( - q, n, reinterpret_cast *>(x), incx, - reinterpret_cast *>(y), incy, - reinterpret_cast *>(result)); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double, - library_data_t::complex_double): { - detail::dotuc_impl( - q, n, reinterpret_cast *>(x), incx, - reinterpret_cast *>(y), incy, - reinterpret_cast *>(result)); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half): { - detail::dotuc_impl( - q, n, reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -template -inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x, - int incx) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else - Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); - auto data_x = get_memory(x); - oneapi::mkl::blas::column_major::scal(q, n, alpha_val, - data_x, incx); -#endif -} - -template -inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x, - int incx, void *y, int incy) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else - Te alpha_val = dpct::get_value(reinterpret_cast(alpha), q); - auto data_x = get_memory(x); - auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::axpy(q, n, alpha_val, - data_x, incx, - data_y, incy); -#endif -} - -template -inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y, - int incy, const void *c, const void *s) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else - Tc c_value = dpct::get_value(reinterpret_cast(c), q); - Ts s_value = dpct::get_value(reinterpret_cast(s), q); - auto data_x = get_memory(x); - auto data_y = get_memory(y); - oneapi::mkl::blas::column_major::rot(q, n, data_x, incx, - data_y, incy, c_value, - s_value); -#endif -} - -template -inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, int lda, const void *b, - int ldb, const void *beta, void *c, int ldc) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); -#endif -} - -template -inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) { - struct matrix_info_t { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->transpose_info[0] = a_trans; - matrix_info->transpose_info[1] = b_trans; - matrix_info->value_info[0] = alpha_value; - matrix_info->value_info[1] = beta_value; - matrix_info->size_info[0] = m; - matrix_info->size_info[1] = n; - matrix_info->size_info[2] = k; - matrix_info->ld_info[0] = lda; - matrix_info->ld_info[1] = ldb; - matrix_info->ld_info[2] = ldc; - matrix_info->groupsize_info = batch_size; - - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); - - q.submit([&](sycl::handler &cgh) { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); - }); -} - -template -inline void -gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, - int k, const void *alpha, const void *a, int lda, - long long int stride_a, const void *b, int ldb, - long long int stride_b, const void *beta, void *c, - int ldc, long long int stride_c, int batch_size) { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); -} - -template -inline void rk_impl(sycl::queue &q, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, int n, int k, - const T *alpha, const T *a, int lda, const T *b, - int ldb, const Tbeta *beta, T *c, int ldc) { - // For symmetric matrix, this function performs: C = alpha*OP(A)*(OP(B))^T + beta*C - // For Hermitian matrix, this function performs: C = alpha*OP(A)*(OP(B))^H + beta*C - // The gemmt() function performs: C = alpha*OPA(A)*OPB(B) + beta*C - // So the OPB need be updated before we call gemmt(). - using Ty = typename dpct::DataType::T2; - using Ts = typename dpct::DataType::T2; - Ty alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - oneapi::mkl::transpose trans_A = trans, trans_B = trans; - int origin_b_rows = trans == oneapi::mkl::transpose::nontrans ? n : k; - int origin_b_cols = trans == oneapi::mkl::transpose::nontrans ? k : n; - - if ((is_hermitian && trans == oneapi::mkl::transpose::trans) || - (!is_hermitian && !std::is_floating_point_v && trans == oneapi::mkl::transpose::conjtrans)) { - // In this case, OPB need be a conjugate operation, - // but only notrans, conjtrans and trans are available. - // So we need do a conjtrans operation first, then do a trans operation. - trans_B = oneapi::mkl::transpose::trans; - auto data_a = get_memory(a); - auto data_c = get_memory(c); -#ifdef DPCT_USM_LEVEL_NONE - auto new_B_buffer = sycl::buffer(sycl::range<1>(origin_b_rows * origin_b_cols)); - auto from_buffer = dpct::get_buffer(b); - oneapi::mkl::blas::column_major::omatcopy_batch( - q, oneapi::mkl::transpose::conjtrans, origin_b_rows, origin_b_cols, - Ts(1.0), from_buffer, ldb, origin_b_rows * ldb, new_B_buffer, - origin_b_cols, origin_b_rows * origin_b_cols, 1); - oneapi::mkl::blas::column_major::gemmt( - q, uplo, trans_A, trans_B, n, k, alpha_value, - data_a, lda, new_B_buffer, origin_b_cols, beta_value, data_c, ldc); -#else - working_memory new_B(origin_b_rows * origin_b_cols * sizeof(T), q); - oneapi::mkl::blas::column_major::omatcopy_batch( - q, oneapi::mkl::transpose::conjtrans, origin_b_rows, origin_b_cols, - Ts(1.0), reinterpret_cast(b), ldb, origin_b_rows * ldb, - reinterpret_cast(new_B.get_ptr()), origin_b_cols, - origin_b_rows * origin_b_cols, 1); - sycl::event e = oneapi::mkl::blas::column_major::gemmt( - q, uplo, trans_A, trans_B, n, k, alpha_value, - data_a, lda, reinterpret_cast(new_B.get_ptr()), origin_b_cols, - beta_value, data_c, ldc); - new_B.set_event(e); -#endif - } else { - if constexpr (is_hermitian) { - trans_B = trans == oneapi::mkl::transpose::nontrans - ? oneapi::mkl::transpose::conjtrans - : oneapi::mkl::transpose::nontrans; - } else { - trans_B = trans == oneapi::mkl::transpose::nontrans - ? oneapi::mkl::transpose::trans - : oneapi::mkl::transpose::nontrans; - } - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemmt( - q, uplo, trans_A, trans_B, n, k, alpha_value, - data_a, lda, data_b, ldb, beta_value, data_c, ldc); - } -} - -template -inline void -trsm_batch_impl(sycl::queue &q, oneapi::mkl::side left_right, - oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, - oneapi::mkl::diag unit_diag, int m, int n, const void *alpha, - const void **a, int lda, void **b, int ldb, int batch_size) { - struct matrix_info_t { - matrix_info_t(oneapi::mkl::side side_info, oneapi::mkl::uplo uplo_info, - oneapi::mkl::transpose transpose_info, - oneapi::mkl::diag diag_info, Ts value_info, std::int64_t m, - std::int64_t n, std::int64_t lda, std::int64_t ldb, - std::int64_t groupsize_info) - : side_info(side_info), uplo_info(uplo_info), - transpose_info(transpose_info), diag_info(diag_info), - value_info(value_info), groupsize_info(groupsize_info) { - size_info[0] = m; - size_info[1] = n; - ld_info[0] = lda; - ld_info[1] = ldb; - } - oneapi::mkl::side side_info; - oneapi::mkl::uplo uplo_info; - oneapi::mkl::transpose transpose_info; - oneapi::mkl::diag diag_info; - Ts value_info; - std::int64_t size_info[2]; - std::int64_t ld_info[2]; - std::int64_t groupsize_info; - }; - - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - - matrix_info_t *matrix_info = - new matrix_info_t(left_right, upper_lower, trans, unit_diag, alpha_value, - m, n, lda, ldb, batch_size); - - sycl::event e = oneapi::mkl::blas::column_major::trsm_batch( - q, &(matrix_info->side_info), &(matrix_info->uplo_info), - &(matrix_info->transpose_info), &(matrix_info->diag_info), - matrix_info->size_info, matrix_info->size_info + 1, - &(matrix_info->value_info), reinterpret_cast(a), - matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, 1, &(matrix_info->groupsize_info)); - - q.submit([&](sycl::handler &cgh) { - cgh.depends_on(e); - cgh.host_task([=] { delete matrix_info; }); - }); -} - -template -inline void getrfnp_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], - int lda, int *info, int batch_size) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project does not support this API."); -#else - using Ty = typename DataType::T2; - // Set the info array value to 0 - detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); - std::int64_t stride_a = n * lda; - std::int64_t scratchpad_size = - oneapi::mkl::lapack::getrfnp_batch_scratchpad_size( - exec_queue, n, n, lda, stride_a, batch_size); - - Ty *a_strided_mem = - (Ty *)dpct::dpct_malloc(stride_a * batch_size * sizeof(Ty), exec_queue); - T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); - dpct::dpct_memcpy(host_a, a, batch_size * sizeof(T *)); - for (std::int64_t i = 0; i < batch_size; ++i) - dpct::dpct_memcpy(a_strided_mem + i * stride_a, host_a[i], - n * lda * sizeof(T)); - -#ifdef DPCT_USM_LEVEL_NONE - { - sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; - auto a_buffer = get_buffer(a_strided_mem); - oneapi::mkl::lapack::getrfnp_batch(exec_queue, n, n, a_buffer, lda, - stride_a, batch_size, scratchpad, - scratchpad_size); - } - std::vector events; - for (std::int64_t i = 0; i < batch_size; ++i) - events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], - a_strided_mem + i * stride_a, - n * lda * sizeof(T), automatic)); -#else - Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); - sycl::event e = oneapi::mkl::lapack::getrfnp_batch( - exec_queue, n, n, a_strided_mem, lda, stride_a, batch_size, scratchpad, - scratchpad_size); - std::vector events; - for (std::int64_t i = 0; i < batch_size; ++i) - events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], - a_strided_mem + i * stride_a, - n * lda * sizeof(T), automatic, {e})); - - std::vector ptrs{scratchpad, a_strided_mem}; - dpct::async_dpct_free(ptrs, events, exec_queue); -#endif - - exec_queue.submit([&](sycl::handler &cgh) { - cgh.depends_on(events); - cgh.host_task([=] { std::free(host_a); }); - }); -#endif -} - -} // namespace detail - -inline oneapi::mkl::transpose get_transpose(int t) { - if (t == 0) { - return oneapi::mkl::transpose::nontrans; - } else if (t == 1) { - return oneapi::mkl::transpose::trans; - } else { - return oneapi::mkl::transpose::conjtrans; - } -} - -/// Computes the LU factorizations of a batch of general matrices. -/// \param [in] exec_queue The queue where the routine should be executed. -/// \param [in] n The order of the matrices. -/// \param [in, out] a Array of pointers to matrices. These matrices will be -/// overwritten by lower triangulars with unit diagonal elements and upper -/// triangulars. -/// \param [in] lda The leading dimension of the matrices. -/// \param [out] ipiv An array stores the pivot indices. If \p ipiv is nullptr, -/// non-pivoting LU factorization is computed. -/// \param [out] info An array stores the error information. -/// \param [in] batch_size The size of the batch. -template -inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], - int lda, int *ipiv, int *info, int batch_size) { - if (ipiv == nullptr) { - detail::getrfnp_batch_wrapper(exec_queue, n, a, lda, info, batch_size); - return; - } - using Ty = typename DataType::T2; - // Set the info array value to 0 - detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); -#ifdef DPCT_USM_LEVEL_NONE - std::int64_t stride_a = n * lda; - std::int64_t stride_ipiv = n; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size( - exec_queue, n, n, lda, stride_a, stride_ipiv, batch_size); - - T *a_buffer_ptr; - a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); - - T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); - dpct_memcpy(host_a, a, batch_size * sizeof(T *)); - for (std::int64_t i = 0; i < batch_size; ++i) - dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); - - { - sycl::buffer ipiv_buf( - sycl::range<1>(batch_size * stride_ipiv)); - sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; - auto a_buffer = get_buffer(a_buffer_ptr); - oneapi::mkl::lapack::getrf_batch(exec_queue, n, n, a_buffer, lda, stride_a, - ipiv_buf, stride_ipiv, batch_size, scratchpad, - scratchpad_size); - - auto to_buffer = get_buffer(ipiv); - exec_queue.submit([&](sycl::handler &cgh) { - auto from_acc = ipiv_buf.get_access(cgh); - auto to_acc = to_buffer.get_access(cgh); - cgh.parallel_for>( - sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { - to_acc[id.get(0) * n + id.get(1)] = - static_cast(from_acc[id.get(0) * stride_ipiv + id.get(1)]); - }); - }); - } - - // Copy back to the original buffers - std::vector events; - for (std::int64_t i = 0; i < batch_size; ++i) - events.push_back(detail::dpct_memcpy(exec_queue, host_a[i], - a_buffer_ptr + i * stride_a, - n * lda * sizeof(T), automatic)); - - std::vector ptrs{host_a}; - std::thread mem_free_thread( - [=](std::vector pointers_array, - std::vector events_array) { - sycl::event::wait(events_array); - for (auto p : pointers_array) - std::free(p); - }, - ptrs, events); - mem_free_thread.detach(); -#else - std::int64_t m_int64 = n; - std::int64_t n_int64 = n; - std::int64_t lda_int64 = lda; - std::int64_t group_sizes = batch_size; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size( - exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes); - - Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); - std::int64_t *ipiv_int64 = - sycl::malloc_device(batch_size * n, exec_queue); - std::int64_t **ipiv_int64_ptr = - sycl::malloc_shared(batch_size, exec_queue); - T **a_shared = sycl::malloc_shared(batch_size, exec_queue); - exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)).wait(); - for (std::int64_t i = 0; i < batch_size; ++i) - ipiv_int64_ptr[i] = ipiv_int64 + n * i; - - oneapi::mkl::lapack::getrf_batch(exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64, - ipiv_int64_ptr, 1, &group_sizes, scratchpad, - scratchpad_size); - - sycl::event e = exec_queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for>( - sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { - ipiv[idx] = static_cast(ipiv_int64[idx]); - }); - }); - - std::vector ptrs{scratchpad, ipiv_int64, ipiv_int64_ptr, a_shared}; - async_dpct_free(ptrs, {e}, exec_queue); -#endif -} - -/// Solves a system of linear equations with a batch of LU-factored square -/// coefficient matrices, with multiple right-hand sides. -/// \param [in] exec_queue The queue where the routine should be executed. -/// \param [in] trans Indicates the form of the linear equations. -/// \param [in] n The order of the matrices. -/// \param [in] nrhs The number of right hand sides. -/// \param [in] a Array of pointers to matrices. -/// \param [in] lda The leading dimension of the matrices in \p a. -/// \param [in] ipiv An array stores the pivots. -/// \param [in, out] b Array of pointers to matrices, whose columns are -/// the right-hand sides for the systems of equations. -/// \param [in] ldb The leading dimension of the matrices in \p b. -/// \param [out] info A value stores the error information. -/// \param [in] batch_size The size of the batch. -template -inline void getrs_batch_wrapper(sycl::queue &exec_queue, - oneapi::mkl::transpose trans, int n, int nrhs, - const T *a[], int lda, const int *ipiv, T *b[], - int ldb, int *info, int batch_size) { - using Ty = typename DataType::T2; - // Set the info value to 0 - *info = 0; -#ifdef DPCT_USM_LEVEL_NONE - std::int64_t stride_a = n * lda; - std::int64_t stride_b = nrhs * ldb; - std::int64_t stride_ipiv = n; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size( - exec_queue, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, - batch_size); - - T *a_buffer_ptr, *b_buffer_ptr; - a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); - b_buffer_ptr = (T *)dpct_malloc(stride_b * batch_size * sizeof(T)); - - T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); - T **host_b = (T **)std::malloc(batch_size * sizeof(T *)); - dpct_memcpy(host_a, a, batch_size * sizeof(T *)); - dpct_memcpy(host_b, b, batch_size * sizeof(T *)); - for (std::int64_t i = 0; i < batch_size; ++i) { - dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); - dpct_memcpy(b_buffer_ptr + i * stride_b, host_b[i], nrhs * ldb * sizeof(T)); - } - - { - auto a_buffer = get_buffer(a_buffer_ptr); - auto b_buffer = get_buffer(b_buffer_ptr); - sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; - sycl::buffer ipiv_buf( - sycl::range<1>(batch_size * stride_ipiv)); - auto from_buf = get_buffer(ipiv); - exec_queue.submit([&](sycl::handler &cgh) { - auto from_acc = from_buf.get_access(cgh); - auto to_acc = ipiv_buf.get_access(cgh); - cgh.parallel_for>( - sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { - to_acc[id.get(0) * stride_ipiv + id.get(1)] = - static_cast(from_acc[id.get(0) * n + id.get(1)]); - }); - }); - - oneapi::mkl::lapack::getrs_batch(exec_queue, trans, n, nrhs, a_buffer, lda, - stride_a, ipiv_buf, stride_ipiv, b_buffer, ldb, - stride_b, batch_size, scratchpad, scratchpad_size); - } - - // Copy back to the original buffers - std::vector events; - for (std::int64_t i = 0; i < batch_size; ++i) - events.push_back(detail::dpct_memcpy(exec_queue, host_b[i], - b_buffer_ptr + i * stride_b, - nrhs * ldb * sizeof(T), automatic)); - std::vector ptrs{host_a, host_b}; - std::thread mem_free_thread( - [=](std::vector pointers_array, - std::vector events_array) { - sycl::event::wait(events_array); - for (auto p : pointers_array) - std::free(p); - }, - ptrs, events); - mem_free_thread.detach(); -#else - std::int64_t n_int64 = n; - std::int64_t nrhs_int64 = nrhs; - std::int64_t lda_int64 = lda; - std::int64_t ldb_int64 = ldb; - std::int64_t group_sizes = batch_size; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size( - exec_queue, &trans, &n_int64, &nrhs_int64, &lda_int64, &ldb_int64, 1, - &group_sizes); - - Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); - std::int64_t *ipiv_int64 = - sycl::malloc_device(batch_size * n, exec_queue); - std::int64_t **ipiv_int64_ptr = - sycl::malloc_shared(batch_size, exec_queue); - T **a_shared = sycl::malloc_shared(batch_size, exec_queue); - T **b_shared = sycl::malloc_shared(batch_size, exec_queue); - exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); - exec_queue.memcpy(b_shared, b, batch_size * sizeof(T *)); - - exec_queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for>( - sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { - ipiv_int64[idx] = static_cast(ipiv[idx]); - }); - }).wait(); - - for (std::int64_t i = 0; i < batch_size; ++i) - ipiv_int64_ptr[i] = ipiv_int64 + n * i; - - sycl::event e = oneapi::mkl::lapack::getrs_batch( - exec_queue, &trans, &n_int64, &nrhs_int64, (Ty **)a_shared, &lda_int64, - ipiv_int64_ptr, (Ty **)b_shared, &ldb_int64, 1, &group_sizes, scratchpad, - scratchpad_size); - - std::vector ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared, b_shared}; - async_dpct_free(ptrs, {e}, exec_queue); -#endif -} - -/// Computes the inverses of a batch of LU-factored matrices. -/// \param [in] exec_queue The queue where the routine should be executed. -/// \param [in] n The order of the matrices. -/// \param [in] a Array of pointers to matrices. -/// \param [in] lda The leading dimension of the matrices in \p a. -/// \param [in] ipiv An array stores the pivots. -/// \param [out] b Array of pointers to inverse matrices. -/// \param [in] ldb The leading dimension of the matrices in \p b. -/// \param [out] info An array stores the error information. -/// \param [in] batch_size The size of the batch. -template -inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, - const T *a[], int lda, int *ipiv, T *b[], - int ldb, int *info, int batch_size) { - using Ty = typename DataType::T2; - // Set the info array value to 0 - detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size); -#ifdef DPCT_USM_LEVEL_NONE - std::int64_t stride_b = n * ldb; - std::int64_t stride_ipiv = n; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size( - exec_queue, n, ldb, stride_b, stride_ipiv, batch_size); - - T *b_buffer_ptr; - b_buffer_ptr = (T *)dpct_malloc(stride_b * batch_size * sizeof(T)); - - T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); - T **host_b = (T **)std::malloc(batch_size * sizeof(T *)); - dpct_memcpy(host_a, a, batch_size * sizeof(T *)); - dpct_memcpy(host_b, b, batch_size * sizeof(T *)); - - for (std::int64_t i = 0; i < batch_size; ++i) { - // Need to create a copy of input matrices "a" to keep them unchanged. - // Matrices "b" (copy of matrices "a") will be used as input and output - // parameter in oneapi::mkl::lapack::getri_batch call. - matrix_mem_copy(b_buffer_ptr + i * stride_b, host_a[i], ldb, lda, n, n, - dpct::device_to_device, exec_queue); - } - - { - auto b_buffer = get_buffer(b_buffer_ptr); - sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; - sycl::buffer ipiv_buf( - sycl::range<1>(batch_size * stride_ipiv)); - auto from_buf = get_buffer(ipiv); - exec_queue.submit([&](sycl::handler &cgh) { - auto from_acc = from_buf.get_access(cgh); - auto to_acc = ipiv_buf.get_access(cgh); - cgh.parallel_for>( - sycl::range<2>(batch_size, n), [=](sycl::id<2> id) { - to_acc[id.get(0) * stride_ipiv + id.get(1)] = - static_cast(from_acc[id.get(0) * n + id.get(1)]); - }); - }); - - oneapi::mkl::lapack::getri_batch(exec_queue, n, b_buffer, ldb, stride_b, ipiv_buf, - stride_ipiv, batch_size, scratchpad, - scratchpad_size); - } - - // Copy back to the original buffers - std::vector events; - for (std::int64_t i = 0; i < batch_size; ++i) - events.push_back(detail::dpct_memcpy(exec_queue, host_b[i], - b_buffer_ptr + i * stride_b, - n * ldb * sizeof(T), automatic)); - std::vector ptrs{host_a, host_b}; - std::thread mem_free_thread( - [=](std::vector pointers_array, - std::vector events_array) { - sycl::event::wait(events_array); - for (auto p : pointers_array) - std::free(p); - }, - ptrs, events); - mem_free_thread.detach(); -#else - std::int64_t n_int64 = n; - std::int64_t ldb_int64 = ldb; - std::int64_t group_sizes = batch_size; - std::int64_t scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size( - exec_queue, &n_int64, &ldb_int64, 1, &group_sizes); - - Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); - std::int64_t *ipiv_int64 = - sycl::malloc_device(batch_size * n, exec_queue); - std::int64_t **ipiv_int64_ptr = - sycl::malloc_shared(batch_size, exec_queue); - - exec_queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for>( - sycl::range<1>(batch_size * n), [=](sycl::id<1> idx) { - ipiv_int64[idx] = static_cast(ipiv[idx]); - }); - }); - - T **a_shared = sycl::malloc_shared(batch_size, exec_queue); - T **b_shared = sycl::malloc_shared(batch_size, exec_queue); - exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); - exec_queue.memcpy(b_shared, b, batch_size * sizeof(T *)).wait(); - for (std::int64_t i = 0; i < batch_size; ++i) { - ipiv_int64_ptr[i] = ipiv_int64 + n * i; - // Need to create a copy of input matrices "a" to keep them unchanged. - // Matrices "b" (copy of matrices "a") will be used as input and output - // parameter in oneapi::mkl::lapack::getri_batch call. - matrix_mem_copy(b_shared[i], a_shared[i], ldb, lda, n, n, dpct::device_to_device, - exec_queue); - } - - sycl::event e = oneapi::mkl::lapack::getri_batch( - exec_queue, &n_int64, (Ty **)b_shared, &ldb_int64, ipiv_int64_ptr, 1, - &group_sizes, scratchpad, scratchpad_size); - - std::vector ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared, b_shared}; - async_dpct_free(ptrs, {e}, exec_queue); -#endif -} - -/// Computes the QR factorizations of a batch of general matrices. -/// \param [in] exec_queue The queue where the routine should be executed. -/// \param [in] m The number of rows in the matrices. -/// \param [in] n The number of columns in the matrices. -/// \param [in, out] a Array of pointers to matrices. These -/// matrices will be overwritten by the factorization data. -/// \param [in] lda The leading dimension of the matrices in \p a. -/// \param [out] tau An array stores the scalars. -/// \param [out] info A value stores the error information. -/// \param [in] batch_size The size of the batch. -template -inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, - T *a[], int lda, T *tau[], int *info, - int batch_size) { - using Ty = typename DataType::T2; - // Set the info value to 0 - *info = 0; -#ifdef DPCT_USM_LEVEL_NONE - std::int64_t stride_a = n * lda; - std::int64_t stride_tau = std::max(1, std::min(m, n)); - std::int64_t scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size( - exec_queue, m, n, lda, stride_a, stride_tau, batch_size); - - T *a_buffer_ptr, *tau_buffer_ptr; - a_buffer_ptr = (T *)dpct_malloc(stride_a * batch_size * sizeof(T)); - tau_buffer_ptr = (T *)dpct_malloc(stride_tau * batch_size * sizeof(T)); - - T **host_a = (T **)std::malloc(batch_size * sizeof(T *)); - T **host_tau = (T **)std::malloc(batch_size * sizeof(T *)); - dpct_memcpy(host_a, a, batch_size * sizeof(T *)); - dpct_memcpy(host_tau, tau, batch_size * sizeof(T *)); - - for (std::int64_t i = 0; i < batch_size; ++i) - dpct_memcpy(a_buffer_ptr + i * stride_a, host_a[i], n * lda * sizeof(T)); - { - auto a_buffer = get_buffer(a_buffer_ptr); - auto tau_buffer = get_buffer(tau_buffer_ptr); - sycl::buffer scratchpad{sycl::range<1>(scratchpad_size)}; - oneapi::mkl::lapack::geqrf_batch(exec_queue, m, n, a_buffer, lda, stride_a, - tau_buffer, stride_tau, batch_size, scratchpad, - scratchpad_size); - } - - // Copy back to the original buffers - std::vector events_a; - std::vector events_tau; - for (std::int64_t i = 0; i < batch_size; ++i) { - events_a.push_back(detail::dpct_memcpy(exec_queue, host_a[i], - a_buffer_ptr + i * stride_a, - n * lda * sizeof(T), automatic)); - events_tau.push_back(detail::dpct_memcpy( - exec_queue, host_tau[i], tau_buffer_ptr + i * stride_tau, - std::max(1, std::min(m, n)) * sizeof(T), automatic)); - } - std::vector ptr_a{host_a}; - std::vector ptr_tau{host_tau}; - std::thread mem_free_thread_a( - [=](std::vector pointers_array, - std::vector events_array) { - sycl::event::wait(events_array); - for (auto p : pointers_array) - std::free(p); - }, - ptr_a, events_a); - std::thread mem_free_thread_tau( - [=](std::vector pointers_array, - std::vector events_array) { - sycl::event::wait(events_array); - for (auto p : pointers_array) - std::free(p); - }, - ptr_tau, events_tau); - mem_free_thread_a.detach(); - mem_free_thread_tau.detach(); -#else - std::int64_t m_int64 = n; - std::int64_t n_int64 = n; - std::int64_t lda_int64 = lda; - std::int64_t group_sizes = batch_size; - std::int64_t scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size( - exec_queue, &m_int64, &n_int64, &lda_int64, 1, &group_sizes); - - Ty *scratchpad = sycl::malloc_device(scratchpad_size, exec_queue); - T **a_shared = sycl::malloc_shared(batch_size, exec_queue); - T **tau_shared = sycl::malloc_shared(batch_size, exec_queue); - exec_queue.memcpy(a_shared, a, batch_size * sizeof(T *)); - exec_queue.memcpy(tau_shared, tau, batch_size * sizeof(T *)).wait(); - - sycl::event e = oneapi::mkl::lapack::geqrf_batch( - exec_queue, &m_int64, &n_int64, (Ty **)a_shared, &lda_int64, (Ty **)tau_shared, 1, - &group_sizes, scratchpad, scratchpad_size); - - std::vector ptrs{scratchpad, a_shared, tau_shared}; - async_dpct_free(ptrs, {e}, exec_queue); -#endif -} - -/// Computes the Euclidean norm of a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void nrm2(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, void *result, library_data_t result_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, result_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::nrm2_impl(q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::nrm2_impl, float>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::nrm2_impl, double>( - q, n, x, incx, result); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::nrm2_impl( - q, n, x, incx, result); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes the dot product of two vectors. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dot(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the dot product of two vectors, conjugating the first vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in] y Input vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [out] result The result scalar. -/// \param [in] result_type Data type of the result. -inline void dotc(sycl::queue &q, int n, const void *x, library_data_t x_type, - int incx, const void *y, library_data_t y_type, int incy, - void *result, library_data_t result_type) { - detail::dotuc(q, n, x, x_type, incx, y, y_type, incy, result, - result_type); -} - -/// Computes the product of a vector by a scalar. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -inline void scal(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, void *x, library_data_t x_type, - int incx) { - std::uint64_t key = detail::get_type_combination_id(x_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_double): { - detail::scal_impl(q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float): { - detail::scal_impl, std::complex>(q, n, alpha, - x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double): { - detail::scal_impl, std::complex>( - q, n, alpha, x, incx); - break; - } - case detail::get_type_combination_id(library_data_t::real_half): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::scal_impl(q, n, &alaph_half, x, incx); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes a vector-scalar product and adds the result to a vector. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in] alpha The scale factor alpha. -/// \param [in] alpha_type The data type of alpha. -/// \param [in] x Input vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -inline void axpy(sycl::queue &q, int n, const void *alpha, - library_data_t alpha_type, const void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy) { - std::uint64_t key = detail::get_type_combination_id(x_type, alpha_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::axpy_impl(q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::axpy_impl, std::complex>( - q, n, alpha, x, incx, y, incy); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - sycl::half alaph_half(alpha_value); - detail::axpy_impl(q, n, &alaph_half, x, incx, y, incy); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Performs rotation of points in the plane. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] n Number of elements in vector x. -/// \param [in, out] x Input/Output vector x. -/// \param [in] x_type Data type of the vector x. -/// \param [in] incx Stride of vector x. -/// \param [in, out] y Input/Output vector y. -/// \param [in] y_type Data type of the vector y. -/// \param [in] incy Stride of vector y. -/// \param [in] c Scaling factor. -/// \param [in] s Scaling factor. -/// \param [in] cs_type Data type of the scaling factors. -inline void rot(sycl::queue &q, int n, void *x, library_data_t x_type, - int incx, void *y, library_data_t y_type, int incy, - const void *c, const void *s, library_data_t cs_type) { - std::uint64_t key = detail::get_type_combination_id(x_type, cs_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::real_float): { - detail::rot_impl, float, float>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::real_double): { - detail::rot_impl, double, double>(q, n, x, incx, y, incy, c, - s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float): { - detail::rot_impl, float, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double): { - detail::rot_impl, double, std::complex>(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_half, - library_data_t::real_half): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - case detail::get_type_combination_id(library_data_t::real_bfloat16, - library_data_t::real_bfloat16): { - detail::rot_impl(q, n, x, incx, y, incy, c, s); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes matrix-matrix product with general matrices. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] a_trans Specifies the operation applied to A. -/// \param [in] b_trans Specifies the operation applied to B. -/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. -/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. -/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). -/// \param [in] alpha Scaling factor for the matrix-matrix product. -/// \param [in] a Input matrix A. -/// \param [in] a_type Data type of the matrix A. -/// \param [in] lda Leading dimension of A. -/// \param [in] b Input matrix B. -/// \param [in] b_type Data type of the matrix B. -/// \param [in] ldb Leading dimension of B. -/// \param [in] beta Scaling factor for matrix C. -/// \param [in, out] c Input/Output matrix C. -/// \param [in] c_type Data type of the matrix C. -/// \param [in] ldc Leading dimension of C. -/// \param [in] scaling_type Data type of the scaling factors. -inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, const void *b, library_data_t b_type, int ldb, - const void *beta, void *c, library_data_t c_type, int ldc, - library_data_t scaling_type) { - bool matched = false; - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) { - scaling_type = library_data_t::complex_float; - } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, - a, lda, b, ldb, &beta_half, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// Computes a batch of matrix-matrix product with general matrices. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] a_trans Specifies the operation applied to A. -/// \param [in] b_trans Specifies the operation applied to B. -/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. -/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. -/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). -/// \param [in] alpha Scaling factor for the matrix-matrix product. -/// \param [in] a Input matrix A. -/// \param [in] a_type Data type of the matrix A. -/// \param [in] lda Leading dimension of A. -/// \param [in] b Input matrix B. -/// \param [in] b_type Data type of the matrix B. -/// \param [in] ldb Leading dimension of B. -/// \param [in] beta Scaling factor for matrix C. -/// \param [in, out] c Input/Output matrix C. -/// \param [in] c_type Data type of the matrix C. -/// \param [in] ldc Leading dimension of C. -/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. -/// \param [in] scaling_type Data type of the scaling factors. -inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) { -#ifdef DPCT_USM_LEVEL_NONE - throw std::runtime_error("this API is unsupported when USM level is none"); -#else - bool matched = false; - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) { - scaling_type = library_data_t::complex_float; - } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } -#ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } -#endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -#endif -} - -/// Computes a batch of matrix-matrix product with general matrices. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] a_trans Specifies the operation applied to A. -/// \param [in] b_trans Specifies the operation applied to B. -/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. -/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. -/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). -/// \param [in] alpha Scaling factor for the matrix-matrix product. -/// \param [in] a Input matrix A. -/// \param [in] a_type Data type of the matrix A. -/// \param [in] lda Leading dimension of A. -/// \param [in] stride_a Stride between the different A matrices. -/// \param [in] b Input matrix B. -/// \param [in] b_type Data type of the matrix B. -/// \param [in] ldb Leading dimension of B. -/// \param [in] stride_b Stride between the different B matrices. -/// \param [in] beta Scaling factor for matrix C. -/// \param [in, out] c Input/Output matrix C. -/// \param [in] c_type Data type of the matrix C. -/// \param [in] ldc Leading dimension of C. -/// \param [in] stride_c Stride between the different C matrices. -/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. -/// \param [in] scaling_type Data type of the scaling factors. -inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, long long int stride_a, const void *b, - library_data_t b_type, int ldb, long long int stride_b, - const void *beta, void *c, library_data_t c_type, - int ldc, long long int stride_c, int batch_size, - library_data_t scaling_type) { - bool matched = false; - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) { - scaling_type = library_data_t::complex_float; - } else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } -#ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } -#endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, - &beta_half, c, ldc, stride_c, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -} - -/// This routines perform a special rank-k update of a symmetric matrix C by -/// general matrices A and B. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] uplo Specifies whether C's data is stored in its upper or lower triangle. -/// \param [in] trans Specifies the operation to apply. -/// \param [in] n The number of rows and columns in C. -/// \param [in] k The inner dimension of matrix multiplications. -/// \param [in] alpha Scaling factor for the rank-k update. -/// \param [in] a Input matrix A. -/// \param [in] lda Leading dimension of A. -/// \param [in] b Input matrix B. -/// \param [in] ldb Leading dimension of B. -/// \param [in] beta Scaling factor for the rank-k update. -/// \param [in, out] c Input/Output matrix C. -/// \param [in] ldc Leading dimension of C. -template -inline void syrk(sycl::queue &q, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, int n, int k, const T *alpha, - const T *a, int lda, const T *b, int ldb, const T *beta, T *c, - int ldc) { - detail::rk_impl(q, uplo, trans, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); -} - -/// This routines perform a special rank-k update of a Hermitian matrix C by -/// general matrices A and B. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] uplo Specifies whether C's data is stored in its upper or lower triangle. -/// \param [in] trans Specifies the operation to apply. -/// \param [in] n The number of rows and columns in C. -/// \param [in] k The inner dimension of matrix multiplications. -/// \param [in] alpha Scaling factor for the rank-k update. -/// \param [in] a Input matrix A. -/// \param [in] lda Leading dimension of A. -/// \param [in] b Input matrix B. -/// \param [in] ldb Leading dimension of B. -/// \param [in] beta Scaling factor for the rank-k update. -/// \param [in, out] c Input/Output matrix C. -/// \param [in] ldc Leading dimension of C. -template -inline void herk(sycl::queue &q, oneapi::mkl::uplo uplo, - oneapi::mkl::transpose trans, int n, int k, const T *alpha, - const T *a, int lda, const T *b, int ldb, const Tbeta *beta, - T *c, int ldc) { - detail::rk_impl(q, uplo, trans, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); -} - -/// This routine performs a group of trsm operations. Each trsm solves an -/// equation of the form op(A) * X = alpha * B or X * op(A) = alpha * B. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] left_right Specifies A multiplies X on the left or on the right. -/// \param [in] upper_lower Specifies A is upper or lower triangular. -/// \param [in] trans Specifies the operation applied to A. -/// \param [in] unit_diag Specifies whether A is unit triangular. -/// \param [in] m Number of rows of the B matrices. -/// \param [in] n Number of columns of the B matrices. -/// \param [in] alpha Scaling factor for the solutions. -/// \param [in] a Input matrices A. -/// \param [in] a_type Data type of the matrices A. -/// \param [in] lda Leading dimension of the matrices A. -/// \param [in, out] b Input and output matrices B. -/// \param [in] b_type Data type of the matrices B. -/// \param [in] ldb Leading dimension of the matrices B. -/// \param [in] batch_size Specifies the number of trsm operations to perform. -/// \param [in] scaling_type Data type of the scaling factors. -inline void trsm_batch(sycl::queue &q, oneapi::mkl::side left_right, - oneapi::mkl::uplo upper_lower, - oneapi::mkl::transpose trans, - oneapi::mkl::diag unit_diag, int m, int n, - const void *alpha, const void **a, library_data_t a_type, - int lda, void **b, library_data_t b_type, int ldb, - int batch_size, library_data_t scaling_type) { -#ifdef DPCT_USM_LEVEL_NONE - throw std::runtime_error("this API is unsupported when USM level is none"); -#else - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, scaling_type); - switch (key) { - case detail::get_type_combination_id(library_data_t::real_float, - library_data_t::real_float, - library_data_t::real_float): { - detail::trsm_batch_impl(q, left_right, upper_lower, - trans, unit_diag, m, n, alpha, - a, lda, b, ldb, batch_size); - break; - } - case detail::get_type_combination_id(library_data_t::real_double, - library_data_t::real_double, - library_data_t::real_double): { - detail::trsm_batch_impl( - q, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, - ldb, batch_size); - break; - } - case detail::get_type_combination_id(library_data_t::complex_float, - library_data_t::complex_float, - library_data_t::complex_float): { - detail::trsm_batch_impl, std::complex, - std::complex>(q, left_right, upper_lower, - trans, unit_diag, m, n, alpha, - a, lda, b, ldb, batch_size); - break; - } - case detail::get_type_combination_id(library_data_t::complex_double, - library_data_t::complex_double, - library_data_t::complex_double): { - detail::trsm_batch_impl, std::complex, - std::complex>(q, left_right, upper_lower, - trans, unit_diag, m, n, alpha, - a, lda, b, ldb, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } -#endif -} - -/// Computes a triangular matrix-general matrix product. -/// \param [in] q The queue where the routine should be executed. -/// \param [in] left_right Specifies A is on the left or right side of the -/// multiplication. -/// \param [in] upper_lower Specifies A is upper or lower triangular. -/// \param [in] trans Specifies the operation applied to A. -/// \param [in] unit_diag Specifies whether A is unit triangular. -/// \param [in] m Number of rows of B. -/// \param [in] n Number of columns of B. -/// \param [in] alpha Scaling factor for the matrix-matrix product. -/// \param [in] a Input matrices A. -/// \param [in] lda Leading dimension of the matrices A. -/// \param [in] b Input matrices B. -/// \param [in] ldb Leading dimension of the matrices B. -/// \param [out] c Output matrices C. -/// \param [in] ldc Leading dimension of the matrices C. -template -inline void trmm(sycl::queue &q, oneapi::mkl::side left_right, - oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, - oneapi::mkl::diag unit_diag, int m, int n, const T *alpha, - const T *a, int lda, const T *b, int ldb, T *c, int ldc) { - using Ty = typename DataType::T2; - auto alpha_val = dpct::get_value(alpha, q); - if (b != c) { - dpct::matrix_mem_copy(c, b, ldc, ldb, m, n, dpct::device_to_device, q); - } - auto data_a = detail::get_memory(a); - auto data_c = detail::get_memory(c); - oneapi::mkl::blas::column_major::trmm(q, left_right, upper_lower, trans, - unit_diag, m, n, alpha_val, data_a, lda, - data_c, ldc); -} - -} // namespace dpct -#endif // __DPCT_BLAS_UTILS_HPP__ diff --git a/dpct/ccl_utils.hpp b/dpct/ccl_utils.hpp deleted file mode 100644 index 07b3488c9..000000000 --- a/dpct/ccl_utils.hpp +++ /dev/null @@ -1,286 +0,0 @@ -//==---- ccl_utils.hpp----------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_CCL_UTILS_HPP__ -#define __DPCT_CCL_UTILS_HPP__ - -#include -#include -#include -#include - -#include "device.hpp" - -namespace dpct { -namespace ccl { -namespace detail { - -/// Get stored kvs with specified kvs address. -inline std::shared_ptr & -get_kvs(const oneapi::ccl::kvs::address_type &addr) { - struct hash { - std::size_t operator()(const oneapi::ccl::kvs::address_type &in) const { - return std::hash()(std::string_view(in.data(), in.size())); - } - }; - static std::unordered_map, hash> - kvs_map; - return kvs_map[addr]; -} - -/// Help class to init ccl environment. -class ccl_init_helper { -public: - ccl_init_helper() { oneapi::ccl::init(); } -}; - -} // namespace detail - -/// Get concatenated library version as an integer. -static inline int get_version() { - oneapi::ccl::init(); - auto ver = oneapi::ccl::get_library_version(); - return ver.major * 10000 + ver.minor * 100 + ver.update; -} - -/// Create main kvs and return its address. -static inline oneapi::ccl::kvs::address_type create_kvs_address() { - oneapi::ccl::init(); - auto ptr = oneapi::ccl::create_main_kvs(); - auto addr = ptr->get_address(); - detail::get_kvs(addr) = ptr; - return addr; -} - -/// Get stored kvs with /p addr if exist. Otherwise, create kvs with /p addr. -static inline std::shared_ptr -create_kvs(const oneapi::ccl::kvs::address_type &addr) { - oneapi::ccl::init(); - auto &ptr = detail::get_kvs(addr); - if (!ptr) - ptr = oneapi::ccl::create_kvs(addr); - return ptr; -} - -/// dpct communicator extension -class communicator_wrapper : public dpct::ccl::detail::ccl_init_helper { -public: - communicator_wrapper( - int size, int rank, oneapi::ccl::kvs::address_type id, - const oneapi::ccl::comm_attr &attr = oneapi::ccl::default_comm_attr) - : _device_comm(oneapi::ccl::create_device( - static_cast(dpct::get_current_device()))), - _context_comm(oneapi::ccl::create_context(dpct::get_default_context())), - _comm(oneapi::ccl::create_communicator( - size, rank, _device_comm, _context_comm, dpct::ccl::create_kvs(id), - attr)) { - _queue_init = false; - _ccl_stream_ptr = nullptr; - } - - ~communicator_wrapper() { - delete _ccl_stream_ptr; - }; - - /// Return the rank in a oneapi::ccl::communicator - /// \returns The rank corresponding to communicator object - int rank() const { - return _comm.rank(); - } - - /// Retrieves the number of rank in oneapi::ccl::communicator - /// \returns The number of the ranks - int size() const { - return _comm.size(); - } - - /// Return underlying native device, which was used in oneapi::ccl::communicator - sycl::device get_device() const { - return _comm.get_device().get_native(); - } - - /// \brief allreduce is a collective communication operation that performs the global reduction operation - /// on values from all ranks of communicator and distributes the result back to all ranks. - /// \param sendbuff the buffer with @c count elements of @c dtype that stores local data to be reduced - /// \param recvbuff [out] the buffer to store reduced result, must have the same dimension as @c sendbuff - /// \param count the number of elements of type @c dtype in @c sendbuff and @c recvbuff - /// \param dtype the datatype of elements in @c sendbuff and @c recvbuff - /// \param rtype the type of the reduction operation to be applied - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void allreduce(const void *sendbuff, void *recvbuff, size_t count, - oneapi::ccl::datatype dtype, oneapi::ccl::reduction rtype, - sycl::queue *queue_ptr) { - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::allreduce(sendbuff, recvbuff, count, dtype, rtype, - _comm, stream); - }, - queue_ptr); - } - - /// \brief reduce is a collective communication operation that performs the - /// global reduction operation on values from all ranks of the communicator - /// and returns the result to the root rank. - /// \param sendbuff the buffer with @c count elements of @c dtype that stores - /// local data to be reduced - /// \param recvbuff [out] the buffer to store reduced result, - /// must have the same dimension as @c sendbuff - /// \param count the number of elements of type @c dtype in @c sendbuff and @c recvbuff - /// \param dtype the datatype of elements in @c sendbuff and @c recvbuff - /// \param root the rank that gets the result of reduction - /// \param rtype the type of the reduction operation to be applied - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void reduce(const void *sendbuff, void *recvbuff, size_t count, - oneapi::ccl::datatype dtype, oneapi::ccl::reduction rtype, - int root, sycl::queue *queue_ptr) { - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::reduce(sendbuff, recvbuff, count, dtype, rtype, - root, _comm, stream); - }, - queue_ptr); - } - - /// \brief broadcast is a collective communication operation that broadcasts data - /// from one rank of communicator (denoted as root) to all other ranks. - /// Only support in-place operation - /// \param sendbuff the buffer with @c count elements of @c dtype that stores - /// local data to be reduced - /// \param recvbuff [out] the buffer to store reduced result - /// \param count the number of elements of type @c dtype in @c buf - /// \param dtype thedatatype of elements in @c buf - /// \param root the rank that broadcasts @c buf - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void broadcast(void *sendbuff, void *recvbuff, size_t count, - oneapi::ccl::datatype dtype, int root, - sycl::queue *queue_ptr) { - if (sendbuff != recvbuff) { - throw std::runtime_error( - "oneCCL broadcast only support in-place operation. " - "sendbuff and recvbuff must be same."); - return; - } - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::broadcast(recvbuff, count, dtype, root, _comm, - stream); - }, - queue_ptr); - } - - /// \brief reduce_scatter is a collective communication operation that performs the global reduction operation - /// on values from all ranks of the communicator and scatters the result in blocks back to all ranks. - /// \param sendbuff the buffer with @c count elements of @c dtype that stores local data to be reduced - /// \param recvbuff [out] the buffer to store reduced result, must have the same dimension as @c sendbuff - /// \param recv_count the number of elements of type @c dtype in receive block - /// \param dtype the datatype of elements in @c sendbuff and @c recvbuff - /// \param rtype the type of the reduction operation to be applied - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void reduce_scatter(const void *sendbuff, void *recvbuff, size_t recv_count, - oneapi::ccl::datatype dtype, oneapi::ccl::reduction rtype, - sycl::queue *queue_ptr) { - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::reduce_scatter(sendbuff, recvbuff, recv_count, - dtype, rtype, _comm, stream); - }, - queue_ptr); - } - - /// \brief send is a pt2pt communication operation that sends data from one rank of communicator. - /// \param sendbuff the buffer with @c count elements of @c dtype serves as send buffer for root - /// \param count the number of elements of type @c dtype in @c sendbuff - /// \param dtype the datatype of elements in @c sendbuff - /// \param peer the rank that receives @c sendbuff - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void send(void *sendbuff, size_t count, oneapi::ccl::datatype dtype, int peer, - sycl::queue *queue_ptr) { - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::send(sendbuff, count, dtype, peer, _comm, stream); - }, - queue_ptr); - } - - /// \brief recv is a pt2pt communication operation that sends data from one rank of communicator. - /// \param recvbuff the buffer with @c count elements of @c dtype serves as receive buffer - /// \param count the number of elements of type @c dtype in @c recvbuff - /// \param dtype the datatype of elements in @c recvbuff - /// \param peer the rank that receives @c recvbuff - /// \param queue_ptr a sycl::queue ptr associated with the operation - /// \return @ref void - void recv(void *recvbuff, size_t count, oneapi::ccl::datatype dtype, int peer, - sycl::queue *queue_ptr) { - call_func_wrapper( - [=](const oneapi::ccl::stream &stream) { - return oneapi::ccl::recv(recvbuff, count, dtype, peer, _comm, stream); - }, - queue_ptr); - } - -private: - oneapi::ccl::device _device_comm; - oneapi::ccl::context _context_comm; - oneapi::ccl::communicator _comm; - sycl::queue _queue; - bool _queue_init; - oneapi::ccl::stream *_ccl_stream_ptr; - - template - void call_func_wrapper(Fn func, sycl::queue *qptr) { - if (_queue_init && *qptr != _queue) { - call_func_async(func, qptr); - } else { - if(!_queue_init) { - _queue = *qptr; - _queue_init = true; - _ccl_stream_ptr = new oneapi::ccl::stream(oneapi::ccl::create_stream(_queue)); - } - std::invoke(func, *_ccl_stream_ptr); - } - } - - class call_func_async { - sycl::queue *_q_ptr; - struct call_async_impl { - oneapi::ccl::stream _ccl_stream_impl; - oneapi::ccl::event _ccl_event_impl; - template - explicit call_async_impl(Fn func, sycl::queue *qptr) - : _ccl_stream_impl(oneapi::ccl::create_stream(*qptr)), - _ccl_event_impl(std::invoke(func, _ccl_stream_impl)) {} - }; - call_async_impl *_imp; - - public: - template - explicit call_func_async(Fn func, sycl::queue *qptr) - : _q_ptr(qptr), - _imp(new call_async_impl(func, qptr)) {} - ~call_func_async() { - _q_ptr->submit([&](sycl::handler &cgh) - { cgh.host_task([=] - { - _imp->_ccl_event_impl.wait(); - delete _imp; }); }); - } - }; -}; - -typedef dpct::ccl::communicator_wrapper *comm_ptr; - -} // namespace ccl -} // namespace dpct - -#endif // __DPCT_CCL_UTILS_HPP__ \ No newline at end of file diff --git a/dpct/device.hpp b/dpct/device.hpp deleted file mode 100644 index 729ebf625..000000000 --- a/dpct/device.hpp +++ /dev/null @@ -1,781 +0,0 @@ -//==---- device.hpp -------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_DEVICE_HPP__ -#define __DPCT_DEVICE_HPP__ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if defined(__linux__) -#include -#include -#endif -#if defined(_WIN64) -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#endif - -namespace dpct { -namespace detail { -static void get_version(const sycl::device &dev, int &major, int &minor) { - // Version string has the following format: - // a. OpenCL - // b. - std::string ver; - ver = dev.get_info(); - std::string::size_type i = 0; - while (i < ver.size()) { - if (isdigit(ver[i])) - break; - i++; - } - major = std::stoi(&(ver[i])); - while (i < ver.size()) { - if (ver[i] == '.') - break; - i++; - } - i++; - minor = std::stoi(&(ver[i])); -} -} // namespace detail - -/// SYCL default exception handler -inline auto exception_handler = [](sycl::exception_list exceptions) { - for (std::exception_ptr const &e : exceptions) { - try { - std::rethrow_exception(e); - } catch (sycl::exception const &e) { - std::cerr << "Caught asynchronous SYCL exception:" << std::endl - << e.what() << std::endl - << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - } - } -}; - -typedef sycl::event *event_ptr; - -typedef sycl::queue *queue_ptr; - -typedef char *device_ptr; - -/// Destroy \p event pointed memory. -/// -/// \param event Pointer to the sycl::event address. -static void destroy_event(event_ptr event) { - delete event; -} - -class device_info { -public: - // get interface - const char *get_name() const { return _name; } - char *get_name() { return _name; } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() const { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else { - return _max_work_item_sizes_i; - } - } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else { - return _max_work_item_sizes_i; - } - } - bool get_host_unified_memory() const { return _host_unified_memory; } - int get_major_version() const { return _major; } - int get_minor_version() const { return _minor; } - int get_integrated() const { return _integrated; } - int get_max_clock_frequency() const { return _frequency; } - int get_max_compute_units() const { return _max_compute_units; } - int get_max_work_group_size() const { return _max_work_group_size; } - int get_max_sub_group_size() const { return _max_sub_group_size; } - int get_max_work_items_per_compute_unit() const { - return _max_work_items_per_compute_unit; - } - int get_max_register_size_per_work_group() const { - return _max_register_size_per_work_group; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() const { - if constexpr (std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() { - if constexpr (std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - size_t get_global_mem_size() const { return _global_mem_size; } - size_t get_local_mem_size() const { return _local_mem_size; } - /// Returns the maximum clock rate of device's global memory in kHz. If - /// compiler does not support this API then returns default value 3200000 kHz. - unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } - /// Returns the maximum bus width between device and memory in bits. If - /// compiler does not support this API then returns default value 64 bits. - unsigned int get_memory_bus_width() const { return _memory_bus_width; } - uint32_t get_device_id() const { return _device_id; } - std::array get_uuid() const { return _uuid; } - /// Returns global memory cache size in bytes. - unsigned int get_global_mem_cache_size() const { - return _global_mem_cache_size; - } - - // set interface - void set_name(const char* name) { - size_t length = strlen(name); - if (length < 256) { - std::memcpy(_name, name, length + 1); - } else { - std::memcpy(_name, name, 255); - _name[255] = '\0'; - } - } - void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) { - for (int i = 0; i < 3; ++i) - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - [[deprecated]] void - set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) { - for (int i = 0; i < 3; ++i) { - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - } - void set_host_unified_memory(bool host_unified_memory) { - _host_unified_memory = host_unified_memory; - } - void set_major_version(int major) { _major = major; } - void set_minor_version(int minor) { _minor = minor; } - void set_integrated(int integrated) { _integrated = integrated; } - void set_max_clock_frequency(int frequency) { _frequency = frequency; } - void set_max_compute_units(int max_compute_units) { - _max_compute_units = max_compute_units; - } - void set_global_mem_size(size_t global_mem_size) { - _global_mem_size = global_mem_size; - } - void set_local_mem_size(size_t local_mem_size) { - _local_mem_size = local_mem_size; - } - void set_max_work_group_size(int max_work_group_size) { - _max_work_group_size = max_work_group_size; - } - void set_max_sub_group_size(int max_sub_group_size) { - _max_sub_group_size = max_sub_group_size; - } - void - set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) { - _max_work_items_per_compute_unit = max_work_items_per_compute_unit; - } - void set_max_nd_range_size(int max_nd_range_size[]) { - for (int i = 0; i < 3; i++) { - _max_nd_range_size[i] = max_nd_range_size[i]; - _max_nd_range_size_i[i] = max_nd_range_size[i]; - } - } - void set_memory_clock_rate(unsigned int memory_clock_rate) { - _memory_clock_rate = memory_clock_rate; - } - void set_memory_bus_width(unsigned int memory_bus_width) { - _memory_bus_width = memory_bus_width; - } - void - set_max_register_size_per_work_group(int max_register_size_per_work_group) { - _max_register_size_per_work_group = max_register_size_per_work_group; - } - void set_device_id(uint32_t device_id) { - _device_id = device_id; - } - void set_uuid(std::array uuid) { - _uuid = std::move(uuid); - } - void set_global_mem_cache_size(unsigned int global_mem_cache_size) { - _global_mem_cache_size = global_mem_cache_size; - } - -private: - char _name[256]; - int _max_work_item_sizes_i[3]; - bool _host_unified_memory = false; - int _major; - int _minor; - int _integrated = 0; - int _frequency; - // Set estimated value 3200000 kHz as default value. - unsigned int _memory_clock_rate = 3200000; - // Set estimated value 64 bits as default value. - unsigned int _memory_bus_width = 64; - unsigned int _global_mem_cache_size; - int _max_compute_units; - int _max_work_group_size; - int _max_sub_group_size; - int _max_work_items_per_compute_unit; - int _max_register_size_per_work_group; - size_t _global_mem_size; - size_t _local_mem_size; - size_t _max_nd_range_size[3]; - int _max_nd_range_size_i[3]; - uint32_t _device_id; - std::array _uuid; -}; - -static int get_major_version(const sycl::device &dev) { - int major, minor; - detail::get_version(dev, major, minor); - return major; -} - -static int get_minor_version(const sycl::device &dev) { - int major, minor; - detail::get_version(dev, major, minor); - return minor; -} - -static void get_device_info(device_info &out, const sycl::device &dev) { - device_info prop; - prop.set_name(dev.get_info().c_str()); - - int major, minor; - detail::get_version(dev, major, minor); - prop.set_major_version(major); - prop.set_minor_version(minor); - - prop.set_max_work_item_sizes( -#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) - // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes - // is an enum class element - dev.get_info()); -#else - // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by - // an int - dev.get_info>()); -#endif - prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); - - prop.set_max_clock_frequency( - dev.get_info() * 1000); - - prop.set_max_compute_units( - dev.get_info()); - prop.set_max_work_group_size( - dev.get_info()); - prop.set_global_mem_size(dev.get_info()); - prop.set_local_mem_size(dev.get_info()); - -#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) - if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) { - unsigned int tmp = - dev.get_info(); - if (tmp != 0) - prop.set_memory_clock_rate(1000 * tmp); - } - if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) { - prop.set_memory_bus_width( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_id)) { - prop.set_device_id( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) { - prop.set_uuid(dev.get_info()); - } -#elif defined(_MSC_VER) && !defined(__clang__) -#pragma message("get_device_info: querying memory_clock_rate and \ -memory_bus_width are not supported by the compiler used. \ -Use 3200000 kHz as memory_clock_rate default value. \ -Use 64 bits as memory_bus_width default value.") -#else -#warning "get_device_info: querying memory_clock_rate and \ -memory_bus_width are not supported by the compiler used. \ -Use 3200000 kHz as memory_clock_rate default value. \ -Use 64 bits as memory_bus_width default value." -#endif - - size_t max_sub_group_size = 1; - std::vector sub_group_sizes = - dev.get_info(); - - for (const auto &sub_group_size : sub_group_sizes) { - if (max_sub_group_size < sub_group_size) - max_sub_group_size = sub_group_size; - } - - prop.set_max_sub_group_size(max_sub_group_size); - - prop.set_max_work_items_per_compute_unit( - dev.get_info()); - int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; - prop.set_max_nd_range_size(max_nd_range_size); - - // Estimates max register size per work group, feel free to update the value - // according to device properties. - prop.set_max_register_size_per_work_group(65536); - - prop.set_global_mem_cache_size( - dev.get_info()); - out = prop; -} - -/// dpct device extension -class device_ext : public sycl::device { - typedef std::mutex mutex_type; - -public: - device_ext() : sycl::device(), _ctx(*this) {} - ~device_ext() { - std::lock_guard lock(m_mutex); - clear_queues(); - } - device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) { - std::lock_guard lock(m_mutex); - init_queues(); - } - - int is_native_atomic_supported() { return 0; } - int get_major_version() const { - return dpct::get_major_version(*this); - } - - int get_minor_version() const { - return dpct::get_minor_version(*this); - } - - int get_max_compute_units() const { - return get_device_info().get_max_compute_units(); - } - - /// Return the maximum clock frequency of this device in KHz. - int get_max_clock_frequency() const { - return get_device_info().get_max_clock_frequency(); - } - - int get_integrated() const { return get_device_info().get_integrated(); } - - int get_max_sub_group_size() const { - return get_device_info().get_max_sub_group_size(); - } - - int get_max_register_size_per_work_group() const { - return get_device_info().get_max_register_size_per_work_group(); - } - - int get_max_work_group_size() const { - return get_device_info().get_max_work_group_size(); - } - - int get_mem_base_addr_align() const { - return get_info(); - } - - size_t get_global_mem_size() const { - return get_device_info().get_global_mem_size(); - } - - /// Get the number of bytes of free and total memory on the SYCL device. - /// \param [out] free_memory The number of bytes of free memory on the SYCL device. - /// \param [out] total_memory The number of bytes of total memory on the SYCL device. - void get_memory_info(size_t &free_memory, size_t &total_memory) { -#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) - if (!has(sycl::aspect::ext_intel_free_memory)) { - std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; - free_memory = 0; - } else { - free_memory = get_info(); - } -#else - std::cerr << "get_memory_info: ext_intel_free_memory is not supported." << std::endl; - free_memory = 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma message("Querying the number of bytes of free memory is not supported") -#else -#warning "Querying the number of bytes of free memory is not supported" -#endif -#endif - total_memory = get_device_info().get_global_mem_size(); - } - - void get_device_info(device_info &out) const { - dpct::get_device_info(out, *this); - } - - device_info get_device_info() const { - device_info prop; - dpct::get_device_info(prop, *this); - return prop; - } - - void reset() { - std::lock_guard lock(m_mutex); - clear_queues(); - init_queues(); - } - - sycl::queue &in_order_queue() { return *_q_in_order; } - - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } - - sycl::queue &default_queue() { -#ifdef DPCT_USM_LEVEL_NONE - return out_of_order_queue(); -#else - return in_order_queue(); -#endif // DPCT_USM_LEVEL_NONE - } - - void queues_wait_and_throw() { - std::unique_lock lock(m_mutex); - std::vector> current_queues( - _queues); - lock.unlock(); - for (const auto &q : current_queues) { - q->wait_and_throw(); - } - // Guard the destruct of current_queues to make sure the ref count is safe. - lock.lock(); - } - - sycl::queue *create_queue(bool enable_exception_handler = false) { -#ifdef DPCT_USM_LEVEL_NONE - return create_out_of_order_queue(enable_exception_handler); -#else - return create_in_order_queue(enable_exception_handler); -#endif // DPCT_USM_LEVEL_NONE - } - - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler); - } - - void destroy_queue(sycl::queue *&queue) { - std::lock_guard lock(m_mutex); - _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr &q) -> bool { - return q.get() == queue; - }), - _queues.end()); - queue = nullptr; - } - void set_saved_queue(sycl::queue* q) { - std::lock_guard lock(m_mutex); - _saved_queue = q; - } - sycl::queue *get_saved_queue() const { - std::lock_guard lock(m_mutex); - return _saved_queue; - } - sycl::context get_context() const { return _ctx; } - -private: - void clear_queues() { - _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; - } - - void init_queues() { - _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); - _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); - } - - /// Caller should acquire resource \p m_mutex before calling this function. - template - sycl::queue *create_queue_impl(bool enable_exception_handler, - Properties... properties) { - sycl::async_handler eh = {}; - if (enable_exception_handler) { - eh = exception_handler; - } - _queues.push_back(std::make_shared( - _ctx, *this, eh, - sycl::property_list( -#ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), -#endif - properties...))); - - return _queues.back().get(); - } - - void get_version(int &major, int &minor) const { - detail::get_version(*this, major, minor); - } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - sycl::context _ctx; - std::vector> _queues; - mutable mutex_type m_mutex; -}; - -static inline unsigned int get_tid() { -#if defined(__linux__) - return syscall(SYS_gettid); -#elif defined(_WIN64) - return GetCurrentThreadId(); -#else -#error "Only support Windows and Linux." -#endif -} - -/// device manager -class dev_mgr { -public: - device_ext ¤t_device() { - unsigned int dev_id=current_device_id(); - check_id(dev_id); - return *_devs[dev_id]; - } - device_ext &cpu_device() const { - std::lock_guard lock(m_mutex); - if (_cpu_device == -1) { - throw std::runtime_error("no valid cpu device"); - } else { - return *_devs[_cpu_device]; - } - } - device_ext &get_device(unsigned int id) const { - std::lock_guard lock(m_mutex); - check_id(id); - return *_devs[id]; - } - unsigned int current_device_id() const { - std::lock_guard lock(m_mutex); - auto it=_thread2dev_map.find(get_tid()); - if(it != _thread2dev_map.end()) - return it->second; - return DEFAULT_DEVICE_ID; - } - -/// Select device with a device ID. -/// \param [in] id The id of the device which can -/// be obtained through get_device_id(const sycl::device). - void select_device(unsigned int id) { - std::lock_guard lock(m_mutex); - check_id(id); - _thread2dev_map[get_tid()]=id; - } - unsigned int device_count() { return _devs.size(); } - - unsigned int get_device_id(const sycl::device &dev) { - unsigned int id = 0; - for(auto dev_item : _devs) { - if (*dev_item == dev) { - break; - } - id++; - } - return id; - } - - template - std::enable_if_t< - std::is_invocable_r_v> - select_device(const DeviceSelector &selector = sycl::gpu_selector_v) { - sycl::device selected_device = sycl::device(selector); - unsigned int selected_device_id = get_device_id(selected_device); - select_device(selected_device_id); - } - - /// Returns the instance of device manager singleton. - static dev_mgr &instance() { - static dev_mgr d_m; - return d_m; - } - dev_mgr(const dev_mgr &) = delete; - dev_mgr &operator=(const dev_mgr &) = delete; - dev_mgr(dev_mgr &&) = delete; - dev_mgr &operator=(dev_mgr &&) = delete; - -private: - mutable std::recursive_mutex m_mutex; - dev_mgr() { - sycl::device default_device = - sycl::device(sycl::default_selector_v); - _devs.push_back(std::make_shared(default_device)); - - std::vector sycl_all_devs = - sycl::device::get_devices(sycl::info::device_type::all); - // Collect other devices except for the default device. - if (default_device.is_cpu()) - _cpu_device = 0; - for (auto &dev : sycl_all_devs) { - if (dev == default_device) { - continue; - } - _devs.push_back(std::make_shared(dev)); - if (_cpu_device == -1 && dev.is_cpu()) { - _cpu_device = _devs.size() - 1; - } - } - } - void check_id(unsigned int id) const { - if (id >= _devs.size()) { - throw std::runtime_error("invalid device id"); - } - } - std::vector> _devs; - /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current - /// thread id in _thread2dev_map, which means default device should be used - /// for the current thread. - const unsigned int DEFAULT_DEVICE_ID = 0; - /// thread-id to device-id map. - std::map _thread2dev_map; - int _cpu_device = -1; -}; - -/// Util function to get the default queue of current selected device depends on -/// the USM config. Return the default out-of-ordered queue when USM-none is -/// enabled, otherwise return the default in-ordered queue. -static inline sycl::queue &get_default_queue() { - return dev_mgr::instance().current_device().default_queue(); -} - -/// Util function to get the default in-ordered queue of current device in -/// dpct device manager. -static inline sycl::queue &get_in_order_queue() { - return dev_mgr::instance().current_device().in_order_queue(); -} - -/// Util function to get the default out-of-ordered queue of current device in -/// dpct device manager. -static inline sycl::queue &get_out_of_order_queue() { - return dev_mgr::instance().current_device().out_of_order_queue(); -} - -/// Util function to get the id of current device in -/// dpct device manager. -static inline unsigned int get_current_device_id() { - return dev_mgr::instance().current_device_id(); -} - -/// Util function to get the current device. -static inline device_ext &get_current_device() { - return dev_mgr::instance().current_device(); -} - -/// Util function to get a device by id. -static inline device_ext &get_device(unsigned int id) { - return dev_mgr::instance().get_device(id); -} - -/// Util function to get the context of the default queue of current -/// device in dpct device manager. -static inline sycl::context get_default_context() { - return dpct::get_current_device().get_context(); -} - -/// Util function to get a CPU device. -static inline device_ext &cpu_device() { - return dev_mgr::instance().cpu_device(); -} - -static inline unsigned int select_device(unsigned int id) { - dev_mgr::instance().select_device(id); - return id; -} - -template -static inline std::enable_if_t< - std::is_invocable_r_v> -select_device(const DeviceSelector &selector = sycl::gpu_selector_v) { - dev_mgr::instance().select_device(selector); -} - -static inline unsigned int get_device_id(const sycl::device &dev){ - return dev_mgr::instance().get_device_id(dev); -} - -/// Util function to check whether a device supports some kinds of sycl::aspect. -inline void -has_capability_or_fail(const sycl::device &dev, - const std::initializer_list &props) { - for (const auto &it : props) { - if (dev.has(it)) - continue; - switch (it) { - case sycl::aspect::fp64: - throw std::runtime_error("'double' is not supported in '" + - dev.get_info() + - "' device"); - break; - case sycl::aspect::fp16: - throw std::runtime_error("'half' is not supported in '" + - dev.get_info() + - "' device"); - break; - default: -#define __SYCL_ASPECT(ASPECT, ID) \ - case sycl::aspect::ASPECT: \ - return #ASPECT; -#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) -#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) - auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string { - switch (AspectNum) { -#include -#include - default: - return "unknown aspect"; - } - }; -#undef __SYCL_ASPECT_DEPRECATED_ALIAS -#undef __SYCL_ASPECT_DEPRECATED -#undef __SYCL_ASPECT - throw std::runtime_error( - "'" + getAspectNameStr(it) + "' is not supported in '" + - dev.get_info() + "' device"); - } - break; - } -} -} // namespace dpct - -#endif // __DPCT_DEVICE_HPP__ diff --git a/dpct/dnnl_utils.hpp b/dpct/dnnl_utils.hpp deleted file mode 100644 index caf5a768b..000000000 --- a/dpct/dnnl_utils.hpp +++ /dev/null @@ -1,4921 +0,0 @@ -//==---- dnnl_utils.hpp ---------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_DNNL_UTILS_HPP__ -#define __DPCT_DNNL_UTILS_HPP__ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "memory.hpp" -#include "device.hpp" -#include "lib_common_utils.hpp" - -namespace dpct { -namespace dnnl { -/// Get concatenated library version as an integer. -static inline size_t get_version() { - const ::dnnl::version_t *ver = ::dnnl::version(); - return ver->major * 1000 + ver->minor * 100 + ver->patch; -} -class engine_ext; -typedef oneapi::mkl::rng::philox4x32x10 rng_engine_t; -/// An enum class representing memory layout. Used by -/// memory_desc_ext to create a memory with pre-defined layout. -enum class memory_format_tag { nchw, nhwc, nchw_blocked }; - -/// An enum class representing RNN data memory layout. Used by -/// memory_desc_ext to create a memory with pre-defined layout. -enum class rnn_memory_format_tag { tnc, ntc }; - -/// A class holding the description of an N-dimensions memory. -class memory_desc_ext { - ::dnnl::memory::desc _desc; -public: - /// Convert dpct::library_data_t to dnnl::memory::data_type. - static ::dnnl::memory::data_type to_dnnl_data_type(dpct::library_data_t dt); - /// Convert dnnl::memory::data_type to dpct::library_data_t. - static dpct::library_data_t - to_dpct_library_data_t(::dnnl::memory::data_type dt, unsigned block_size); - /// Convert dpct::dnnl::memory_format_tag to dnnl::memory::format_tag. - static ::dnnl::memory::format_tag to_dnnl_format_tag(dpct::library_data_t dt, - memory_format_tag tag); - memory_desc_ext() = default; - memory_desc_ext(::dnnl::memory::desc &desc) : _desc(desc) {} - memory_desc_ext(::dnnl::memory::desc &&desc) : _desc(std::move(desc)) {} - /// Setting a 4D memory with given parameters. - /// \param [in] tag Format tag. - /// \param [in] dt Data type. - /// \param [in] n Number of images. - /// \param [in] c Number of channels. - /// \param [in] h Height of images. - /// \param [in] w Width of images. - void set(memory_format_tag tag, dpct::library_data_t dt, int n, int c, int h, - int w); - /// Setting a 3D RNN data memory with given parameters. - /// \param [in] tag RNN data format tag. - /// \param [in] dt Data type. - /// \param [in] t Number of sequence length. - /// \param [in] n Number of batch. - /// \param [in] c Height of input channel. - void set(rnn_memory_format_tag tag, dpct::library_data_t dt, int t, int n, int c); - /// Setting a 4D memory with given parameters. - /// \param [in] dt Data type. - /// \param [in] n Number of images. - /// \param [in] c Number of channels. - /// \param [in] h Height of images. - /// \param [in] w Width of images. - /// \param [in] n_stride Stride between two continuous images. - /// \param [in] c_stride Stride between two continuous channels. - /// \param [in] h_stride Stride between two continuous rows. - /// \param [in] w_stride Stride between two continuous columns. - void set(dpct::library_data_t dt, int n, int c, int h, int w, int n_stride, - int c_stride, int h_stride, int w_stride); - /// Setting a ND memory with given parameters. - /// \param [in] dt Data type. - /// \param [in] ndims Dimension of the memory. - /// \param [in] dims Array of dimension ndims that contain the size of each - /// memory dimension. \param [in] strides Array of dimension ndims that - /// contain the stride of each memory dimension. - void set(dpct::library_data_t dt, int ndims, const int dims[], - const int strides[]); - /// Setting a ND memory with given parameters. - /// \param [in] tag Format tag. - /// \param [in] dt Data type. - /// \param [in] ndims Dimension of the memory. - /// \param [in] dims Array of dimension ndims that contain the size of each - /// memory dimension. - void set(memory_format_tag tag, dpct::library_data_t dt, int ndims, - const int dims[]); - /// Getting a ::dnnl::memory::desc from a memory_desc_ext. - /// \returns The ::dnnl::memory::desc. - const ::dnnl::memory::desc &get_desc() const { return _desc; } - /// Setting holding desc with given dnnl memory descriptor. - void set_desc(::dnnl::memory::desc desc) { _desc = desc; } - /// Getting a size of a memory_desc_ext in bytes. - /// \returns The size. - size_t get_size() const { return _desc.get_size(); } - /// Getting parameters from a 4D memory. - /// \param [out] dt Data type. - /// \param [out] n Number of images. - /// \param [out] c Number of channels. - /// \param [out] h Height of images. - /// \param [out] w Width of images. - /// \param [out] n_stride Stride between two continuous images. - /// \param [out] c_stride Stride between two continuous channels. - /// \param [out] h_stride Stride between two continuous rows. - /// \param [out] w_stride Stride between two continuous columns. - void get(dpct::library_data_t *dt, int *n, int *c, int *h, int *w, - int *n_stride, int *c_stride, int *h_stride, int *w_stride) const; - /// Getting parameters from a 4D memory. - /// \param [out] dt Data type. - /// \param [out] tag Format tag. - /// \param [out] n Number of images. - /// \param [out] c Number of channels. - /// \param [out] h Height of images. - /// \param [out] w Width of images. - void get(dpct::library_data_t *dt, memory_format_tag *tag, int *n, int *c, - int *h, int *w) const; - /// Getting parameters from a 3D RNN data memory. - /// \param [out] dt Data type. - /// \param [out] tag RNN data format tag. - /// \param [out] t Number of sequence length. - /// \param [out] n Number of batch. - /// \param [out] c Height of input channel. - void get(dpct::library_data_t *dt, rnn_memory_format_tag *tag, int *t, int *n, - int *c) const; - /// Getting parameters from a ND memory. - /// \param [in] requested_ndims Requested number of dimensions to get from a - /// given memory descriptor. - /// \param [out] dt Data type. - /// \param [out] ndims Dimension of the memory. - /// \param [out] dims Array of dimension requested_ndims that contain the - /// size of each memory dimension. - /// \param [out] strides Array of dimension requested_ndims that contain the - /// stride of each memory dimension. - void get(int requested_ndims, dpct::library_data_t *dt, int *ndims, - int dims[], int strides[]) const; - /// Getting parameters from a ND memory. - /// \param [in] requested_ndims Requested number of dimensions to get from a - /// given memory descriptor. - /// \param [out] dt Data type. - /// \param [out] tag Format tag. - /// \param [out] ndims Dimension of the memory. - /// \param [out] dims Array of dimension requested_ndims that contain the - /// size of each memory dimension. - void get(int requested_ndims, dpct::library_data_t *dt, - memory_format_tag *tag, int *ndims, int dims[]) const; - /// Getting dims from a ND memory. - /// \return The dims. - std::vector get_dims() const { return _desc.get_dims(); } - /// Getting strides from a ND memory. - /// \return The strides. - std::vector get_strides() const { - return _desc.get_strides(); - } - /// Getting element num from a ND memory. - /// \return The element number. - size_t get_element_num() const { - auto dims = _desc.get_dims(); - if (dims.empty()) { - return 0; - } - size_t result = 1; - for (auto &dim : dims) { - result *= dim; - } - return result; - } - - operator bool() const { - return bool(_desc); - } - - memory_desc_ext &operator=(std::nullptr_t) { - _desc.reset(nullptr); - return *this; - } -}; - -/// A class holding description for an activation operation. -class activation_desc { - ::dnnl::algorithm _alg; - float _alpha; - float _beta; - -public: - /// Setting an activation descriptor with given parameters. - /// \param [in] alg Activation algorithm. - /// \param [in] alpha Value of alpha parameter. - void set(::dnnl::algorithm alg, float alpha) { - _alg = alg; - if(alg == ::dnnl::algorithm::eltwise_clip) { - _alpha = 0; - _beta = alpha; - } else { - _alpha = alpha; - } - } - /// Getting parameters form an activation descriptor. - /// \param [out] alg Activation algorithm. - /// \param [out] alpha Value of alpha parameter. - void get(::dnnl::algorithm *alg, float *alpha) const { - *alg = _alg; - if(_alg == ::dnnl::algorithm::eltwise_clip) { - *alpha = _beta; - } else { - *alpha = _alpha; - } - } - /// Setting the alpha parameter of an activation descriptor. - /// \param [in] alpha Value of alpha parameter. - void set_alpha(float alpha) { _alpha = alpha; } - /// Setting the beta parameter of an activation descriptor. - /// \param [in] beta Value of beta parameter. - void set_beta(float beta) { _beta = beta; } - /// Setting the algorithm parameter of an activation descriptor. - /// \param [in] alg Activation algorithm. - void set_algorithm(::dnnl::algorithm alg) { _alg = alg; } - /// Getting the alpha parameter from an activation descriptor. - /// \param [out] alpha Value of alpha parameter. - float get_alpha() const { return _alpha; } - /// Getting the beta parameter from an activation descriptor. - /// \param [out] beta Value of beta parameter. - float get_beta() const { return _beta; } - /// Getting the algorithm parameter from an activation descriptor. - /// \param [out] alg Activation algorithm. - ::dnnl::algorithm get_algorithm() const { return _alg; } -}; - -/// A class holding description for a local response normalization operation. -class lrn_desc { - unsigned int _local_size; - float _alpha; - float _beta; - float _k; - -public: - /// Setting a local response normalization descriptor with given parameters. - /// \param [in] local_size Value of local_size parameter. - /// \param [in] alpha Value of alpha parameter. - /// \param [in] beta Value of beta parameter. - /// \param [in] k Value of k parameter. - void set(unsigned int local_size, float alpha, float beta, float k) { - _local_size = local_size; - _alpha = alpha; - _beta = beta; - _k = k; - } - /// Getting parameters form a local response normalization descriptor. - /// \param [out] local_size Value of local_size parameter. - /// \param [out] alpha Value of alpha parameter. - /// \param [out] beta Value of beta parameter. - /// \param [out] k Value of k parameter. - void get(unsigned int *local_size, float *alpha, float *beta, - float *k) const { - *local_size = _local_size; - *alpha = _alpha; - *beta = _beta; - *k = _k; - } - /// Setting the local size parameter of a local response normalization - /// descriptor. - /// \param [in] local_size Value of local_size parameter. - void set_local_size(unsigned int local_size) { _local_size = local_size; } - /// Setting the alpha parameter of a local response normalization descriptor. - /// \param [in] alpha Value of alpha parameter. - void set_alpha(float alpha) { _alpha = alpha; } - /// Setting the beta parameter of a local response normalization descriptor. - /// \param [in] beta Value of beta parameter. - void set_beta(float beta) { _beta = beta; } - /// Setting the k parameter of a local response normalization descriptor. - /// \param [in] k Value of k parameter. - void set_k(float k) { _k = k; } - /// Getting the local size parameter from a local response normalization - /// descriptor. - /// \param [out] local_size Value of local_size parameter. - unsigned int get_local_size() const { return _local_size; } - /// Getting the alpha parameter from a local response normalization - /// descriptor. - /// \param [out] alpha Value of alpha parameter. - float get_alpha() const { return _alpha; } - /// Getting the beta parameter from a local response normalization descriptor. - /// \param [out] beta Value of beta parameter. - float get_beta() const { return _beta; } - /// Getting the k parameter from a local response normalization descriptor. - /// \param [out] k Value of k parameter. - float get_k() const { return _k; } -}; - -/// An enum class representing softmax algorithm. -enum class softmax_algorithm { normal, log }; -/// An enum class representing softmax mode. -enum class softmax_mode { instance, channel }; - -/// A class holding description for a pooling operation. -class pooling_desc { - ::dnnl::algorithm _alg; - std::vector _stride; - std::vector _kernel; - std::vector _padding; - -public: - /// Setting a 2D pooling descriptor with given parameters. - /// \param [in] alg Pooling algorithm. - /// \param [in] kernel_h Value of height of kernel. - /// \param [in] kernel_w Value of width of kernel. - /// \param [in] padding_h Value of height of padding. - /// \param [in] padding_w Value of width of padding. - /// \param [in] stride_h Value of height of stride. - /// \param [in] stride_w Value of width of stride. - void set(::dnnl::algorithm alg, int kernel_h, int kernel_w, int padding_h, - int padding_w, int stride_h, int stride_w) { - _alg = alg; - _stride = {stride_h, stride_w}; - _kernel = {kernel_h, kernel_w}; - _padding = {padding_h, padding_w}; - } - /// Setting a ND pooling descriptor with given parameters. - /// \param [in] alg Pooling algorithm. - /// \param [in] ndims Dimension of the pooling operation. - /// \param [in] kernel Array of dimension ndims containing the kernel size of - /// each dimension. - /// \param [in] padding Array of dimension ndims containing the padding size of - /// each dimension. - /// \param [in] stride Array of dimension ndims containing the stride size of - /// each dimension. - void set(::dnnl::algorithm alg, int ndims, int kernel[], int padding[], - int stride[]) { - _alg = alg; - _stride = std::vector(stride, stride + ndims); - _kernel = std::vector(kernel, kernel + ndims); - _padding = std::vector(padding, padding + ndims); - } - /// Getting parameters from a 2D pooling descriptor. - /// \param [out] alg Pooling algorithm. - /// \param [out] kernel_h Value of height of kernel. - /// \param [out] kernel_w Value of width of kernel. - /// \param [out] padding_h Value of height of padding. - /// \param [out] padding_w Value of width of padding. - /// \param [out] stride_h Value of height of stride. - /// \param [out] stride_w Value of width of stride. - void get(::dnnl::algorithm *alg, int *kernel_h, int *kernel_w, int *padding_h, - int *padding_w, int *stride_h, int *stride_w) const { - *alg = _alg; - *kernel_h = _kernel[0]; - *kernel_w = _kernel[1]; - *padding_h = _padding[0]; - *padding_w = _padding[1]; - *stride_h = _stride[0]; - *stride_w = _stride[1]; - } - /// Getting parameters from a ND pooling descriptor. - /// \param [in] requested_ndims Requested number of dimensions to get from a - /// given pooling descriptor. - /// \param [out] alg Pooling algorithm. - /// \param [out] ndims Dimension of the pooling operation. - /// \param [out] kernel Array of dimension ndims containing the kernel size of - /// each dimension. - /// \param [out] padding Array of dimension ndims containing the padding size - /// of each dimension. - /// \param [out] stride Array of dimension ndims containing the stride size of - /// each dimension. - void get(int requested_ndims, ::dnnl::algorithm *alg, int *ndims, - int kernel[], int padding[], int stride[]) const { - *alg = _alg; - *ndims = _stride.size(); - for (int i = 0; i < requested_ndims; i++) { - kernel[i] = _kernel[i]; - padding[i] = _padding[i]; - stride[i] = _stride[i]; - } - } - /// Setting the algorithm parameter of a pooling descriptor. - /// \param [in] alg Pooling algorithm. - void set_algorithm(::dnnl::algorithm alg) { _alg = alg; } - /// Setting the stride parameter of a pooling descriptor. - /// \param [in] stride Array of dimension ndims containing the stride size of - /// each dimension. - void set_stride(const std::vector &stride) { _stride = stride; } - /// Setting the kernel parameter of a pooling descriptor. - /// \param [in] kernel Array of dimension ndims containing the kernel size of - /// each dimension. - void set_kernel(const std::vector &kernel) { _kernel = kernel; } - /// Setting the padding parameter of a pooling descriptor. - /// \param [in] padding Array of dimension ndims containing the padding size - /// of each dimension. - void set_padding(const std::vector &padding) { _padding = padding; } - - /// Getting the algorithm parameter from a pooling descriptor. - /// \param [out] alg Pooling algorithm. - ::dnnl::algorithm get_algorithm() const { return _alg; } - /// Getting the stride parameter from a pooling descriptor. - /// \returns Array of dimension ndims containing the stride size of each - /// dimension. - const std::vector &get_stride() const { return _stride; } - /// Getting the kernel parameter from a pooling descriptor. - /// \returns Array of dimension ndims containing the kernel size of each - /// dimension. - const std::vector &get_kernel() const { return _kernel; } - /// Getting the padding parameter from a pooling descriptor. - /// \returns Array of dimension ndims containing the padding size of each - /// dimension. - const std::vector &get_padding() const { return _padding; } - /// Getting the output dimensions of a memory after 2D pooling has been - /// applied. - /// \param [in] desc Input memory descriptor. - /// \param [out] out_n Number of images. - /// \param [out] out_c Number of channels. - /// \param [out] out_h Height of images. - /// \param [out] out_w Width of images. - void get_forward_output_dim(const memory_desc_ext &desc, int *out_n, - int *out_c, int *out_h, int *out_w) const { - auto dims = desc.get_dims(); - *out_n = dims[0]; - *out_c = dims[1]; - *out_h = 1 + (dims[2] + 2 * _padding[0] - _kernel[0]) / _stride[0]; - *out_w = 1 + (dims[3] + 2 * _padding[1] - _kernel[1]) / _stride[1]; - } - /// Getting the output dimensions of a memory after ND pooling has been - /// applied. - /// \param [in] desc Input memory descriptor. - /// \param [out] ndims Dimension of the memory. - /// \param [out] out_dims Array of dimension requested_ndims that contain - /// the size of each memory dimension. - void get_forward_output_dim(const memory_desc_ext &desc, int ndims, - int out_dims[]) const { - assert(ndims >= 4 && "ndims is at least 4."); - auto dims = desc.get_dims(); - out_dims[0] = dims[0]; - out_dims[1] = dims[1]; - for (int i = 2; i < ndims; i++) { - out_dims[i] = - 1 + (dims[i] + 2 * _padding[i - 2] - _kernel[i - 2]) / _stride[i - 2]; - } - } -}; - -/// An enum class representing reduction operations. -enum class reduction_op { - max, - min, - sum, - mul, - mean, - amax, - mul_no_zeros, - norm1, - norm2 -}; - -/// An enum class representing batch normalization mode. -enum class batch_normalization_mode { per_activation, spatial }; - -/// An enum class representing batch normalization operations. -enum class batch_normalization_ops { none, activation, add_activation }; - -/// An enum class representing binary operations. -enum class binary_op { add, sub, mul, div, min, max, sqrt, neg }; - -/// An struct representing convolution algorithm infomation. -struct convolution_algorithm_info { - ::dnnl::algorithm algo = ::dnnl::algorithm::convolution_auto; - int status = 0; -}; - -/// A class holding description for a convolution operation. -class convolution_desc { - std::vector _strides; - std::vector _dilates; - std::vector _paddings; - int _group_count = 1; - ::dnnl::fpmath_mode _math_mode = ::dnnl::fpmath_mode::strict; -public: - /// Setting a group count to be used in the convolution. - /// \param [in] group_count Value of group count. - void set_group_count(int group_count) { _group_count = group_count; } - /// Getting a group count specified in the given convolution descriptor. - /// \returns Value of group count. - int get_group_count() { return _group_count; } - /// Setting floating point math mode to be used in the convolution. - /// \param [in] math_mode Value of math_mode. - void set_math_mode(::dnnl::fpmath_mode math_mode) { _math_mode = math_mode; } - /// Getting floating point math mode specified in the given convolution descriptor. - /// \returns Value of math mode. - ::dnnl::fpmath_mode get_math_mode() { return _math_mode; } - /// Setting a 2D convolution descriptor with given parameters. - /// \param [in] padding_h Value of height of padding. - /// \param [in] padding_w Value of width of padding. - /// \param [in] stride_h Value of height of stride. - /// \param [in] stride_w Value of width of stride. - /// \param [in] dilate_h Value of height of dilate. - /// \param [in] dilate_w Value of width of dilate. - void set(int padding_h, int padding_w, int stride_h, int stride_w, - int dilate_h, int dilate_w) { - _strides = {stride_h, stride_w}; - _dilates = {dilate_h - 1, dilate_w - 1}; - _paddings = {padding_h, padding_w}; - } - /// Setting a ND convolution descriptor with given parameters. - /// \param [in] ndims Dimension of the convolution operation. - /// \param [in] paddings Array of dimension ndims containing the padding size of - /// each dimension. - /// \param [in] strides Array of dimension ndims containing the stride size of - /// each dimension. - /// \param [in] dilates Array of dimension ndims containing the kernel size of - /// each dimension. - void set(int ndims, int paddings[], int strides[], int dilates[]) { - _strides = std::vector(strides, strides + ndims); - _paddings = std::vector(paddings, paddings + ndims); - _dilates = std::vector(dilates, dilates + ndims); - for (auto &dilate : _dilates) { - dilate--; - } - } - /// Getting parameters from a 2D convolution descriptor. - /// \param [out] padding_h Value of height of padding. - /// \param [out] padding_w Value of width of padding. - /// \param [out] stride_h Value of height of stride. - /// \param [out] stride_w Value of width of stride. - /// \param [out] dilate_h Value of height of dilate. - /// \param [out] dilate_w Value of width of dilate. - void get(int *padding_h, int *padding_w, int *stride_h, int *stride_w, - int *dilate_h, int *dilate_w) const { - *dilate_h = _dilates[0]; - *dilate_w = _dilates[1]; - *padding_h = _paddings[0]; - *padding_w = _paddings[1]; - *stride_h = _strides[0]; - *stride_w = _strides[1]; - } - /// Getting parameters from a ND convolution descriptor. - /// \param [in] requested_ndims Requested number of dimensions to get from a - /// given convolution descriptor. - /// \param [out] ndims Dimension of the pooling operation. - /// \param [out] paddings Array of dimension ndims containing the padding size - /// of each dimension. - /// \param [out] strides Array of dimension ndims containing the stride size of - /// each dimension. - /// \param [out] dilates Array of dimension ndims containing the dilate size of - /// each dimension. - void get(int requested_ndims, int *ndims, int paddings[], int strides[], - int dilates[]) const { - *ndims = _strides.size(); - for (int i = 0; i < requested_ndims; i++) { - dilates[i] = _dilates[i]; - paddings[i] = _paddings[i]; - strides[i] = _strides[i]; - } - } - /// Getting the stride parameter from a convolution descriptor. - /// \returns Array of dimension ndims containing the stride size of each - /// dimension. - const std::vector &get_stride() const { return _strides; } - /// Getting the kernel parameter from a convolution descriptor. - /// \returns Array of dimension ndims containing the dilate size of each - /// dimension. - const std::vector &get_dilate() const { return _dilates; } - /// Getting the padding parameter from a convolution descriptor. - /// \returns Array of dimension ndims containing the padding size of each - /// dimension. - const std::vector &get_padding() const { return _paddings; } - /// Getting the output dimensions of a memory after 2D convolution has been - /// applied. - /// \param [in] desc Input memory descriptor. - /// \param [in] weight_desc Input weight memory descriptor. - /// \param [out] out_n Number of images. - /// \param [out] out_c Number of channels. - /// \param [out] out_h Height of images. - /// \param [out] out_w Width of images. - void get_forward_output_dim(const memory_desc_ext &desc, - const memory_desc_ext &weight_desc, int *out_n, - int *out_c, int *out_h, int *out_w) const { - auto dims = desc.get_dims(); - auto weight_dims = weight_desc.get_dims(); - *out_n = dims[0]; - *out_c = weight_dims[0]; - *out_h = 1 + (dims[2] + 2 * _paddings[0] - - (1 + (_dilates[0] * (weight_dims[2] - 1)))) / - _strides[0]; - *out_w = 1 + (dims[3] + 2 * _paddings[1] - - (1 + (_dilates[1] * (weight_dims[3] - 1)))) / - _strides[1]; - } - /// Getting the output dimensions of a memory after ND convolution has been - /// applied. - /// \param [in] desc Input memory descriptor. - /// \param [in] weight_desc Input weight memory descriptor. - /// \param [out] ndims Dimension of the memory. - /// \param [out] out_dims Array of dimension requested_ndims that contain - /// the size of each memory dimension. - void get_forward_output_dim(const memory_desc_ext &desc, - const memory_desc_ext &weight_desc, int ndims, - int out_dims[]) const { - assert(ndims >= 4 && "ndims is at least 4."); - auto dims = desc.get_dims(); - auto weight_dims = weight_desc.get_dims(); - out_dims[0] = dims[0]; - out_dims[1] = weight_dims[1]; - for (int i = 2; i < ndims; i++) { - out_dims[i] = 1 + (dims[i] + 2 * _paddings[i - 2] - - (1 + (_dilates[i - 2] * (weight_dims[i] - 1)))) / - _strides[i - 2]; - } - } - - convolution_desc &operator=(std::nullptr_t) { - return *this = convolution_desc(); - } - - operator bool() const { - return !(_strides.size() == 0 - && _dilates.size() == 0 - && _paddings.size() == 0); - } -}; - -/// An enum class representing rnn mode. -enum class rnn_mode { vanilla_relu, vanilla_tanh, lstm, gru }; - -/// An enum class representing rnn bias mode. -enum class rnn_bias_mode { none, single }; - -/// An enum class representing rnn direction. -enum class rnn_direction {unidirectional, bidirectional}; - -/// A class holding description for a RNN operation. -class rnn_desc { - rnn_mode _mode; - rnn_bias_mode _bias_mode; - rnn_direction _direction; - dpct::library_data_t _dt; - int _input_size; - int _hidden_size; - int _projection_size; - int _layer_size; - -public: - void set(rnn_mode mode, rnn_bias_mode bias_mode, rnn_direction direction, - dpct::library_data_t dt, int input_size, int hidden_size, - int projection_size, int layer_size) { - _mode = mode; - _bias_mode = bias_mode; - _direction = direction; - _input_size = input_size; - _hidden_size = hidden_size; - _projection_size = projection_size; - _layer_size = layer_size; - _dt = dt; - } - void get(rnn_mode *mode, rnn_bias_mode *bias_mode, rnn_direction *direction, - dpct::library_data_t *dt, int *input_size, int *hidden_size, - int *projection_size, int *layer_size) const { - *mode = _mode; - *bias_mode = _bias_mode; - *direction = _direction; - *input_size = _input_size; - *hidden_size = _hidden_size; - *projection_size = _projection_size; - *layer_size = _layer_size; - *dt = _dt; - } -}; - -/// A class holding description for a Dropout operation. -class dropout_desc { - struct dropout_desc_imp { - float _p = 0.5f; - unsigned long long _seed = 1; - void *_state = nullptr; - std::vector _host_state; - rng_engine_t _rng_engine; - dropout_desc_imp() : _rng_engine(dpct::get_default_queue(), 1) {} - }; - std::shared_ptr _imp; - - void generate(sycl::queue *q, std::int64_t required_state_size, - std::int64_t num, void *buffer) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " - "Interfaces Project does not support this API."); -#else - sycl::event e_gen = oneapi::mkl::rng::generate( - oneapi::mkl::rng::bernoulli(1.f - _imp->_p), - _imp->_rng_engine, num, (std::int32_t *)buffer); - sycl::event e_save = q->submit([&](sycl::handler &cgh) { - cgh.depends_on(e_gen); - cgh.host_task([=] { - oneapi::mkl::rng::save_state(_imp->_rng_engine, - _imp->_host_state.data()); - }); - }); - q->memcpy(_imp->_state, _imp->_host_state.data(), required_state_size, - e_save); -#endif - } -public: - operator bool() const { - return bool(_imp); - } - dropout_desc &operator=(std::nullptr_t) { - _imp.reset(); - return *this; - } - /// Initializing a dropout descriptor. - void init(){ - _imp = std::make_shared(); - } - /// Setting a dropout descriptor with given parameters. - /// \param [in] engine Engine of the dropout operation. - /// \param [in] p Probability of value set to zero. - /// \param [in] state Memory that store random generator state. - /// \param [in] state_size Required size to store random generator state. - /// \param [in] seed Seed to initialize conditions of the generator state. - void set(engine_ext &engine, float p, void *state, size_t state_size, - unsigned long long seed); - /// Getting parameters from a dropout descriptor. - /// \param [in] engine Engine of the dropout operation. - /// \param [in] p Probability of value set to zero. - /// \param [in] state Memory that store random generator state. - /// \param [in] seed Seed to initialize conditions of the generator state. - void get(float *p, void **states, unsigned long long *seed) const noexcept { - *seed = _imp->_seed; - *states = _imp->_state; - *p = _imp->_p; - } - /// Getting the probability of value set to zero. - /// \returns Probability. - float get_probability() const noexcept { return _imp->_p; } - /// Restoreing a dropout descriptor from stored state. - /// \param [in] engine Engine of the dropout operation. - /// \param [in] p Probability of value set to zero. - /// \param [in] state Memory that store random generator state. - /// \param [in] state_size Required size to store random generator state. - /// \param [in] seed Seed to initialize conditions of the generator state. - void restore(engine_ext &engine, float p, void *state, size_t state_size, - unsigned long long seed); - friend class engine_ext; -}; - -namespace detail { -typedef std::string primitive_cache_key_type; -typedef std::list usage_list_type; -struct primitive_cache_value_type { - ::dnnl::primitive *_primitive; - std::unordered_map *_args; - usage_list_type::iterator _usage_it; - std::function _destructor; - sycl::event _e; - sycl::queue _q; - primitive_cache_value_type( - ::dnnl::primitive *primitive, - std::unordered_map *args, - usage_list_type::iterator usage_it, - std::function destructor, sycl::event e, - sycl::queue q) - : _primitive(primitive), _args(args), _usage_it(usage_it), - _destructor(destructor), _e(e), _q(q) {} -}; -struct primitive_and_args { - ::dnnl::primitive *primitive; - std::unordered_map *args; -}; -typedef std::unordered_map> - cache_map_type; - -// The primitive cache uses LRU replacement policy, and the default cache -// capacity is 1024. -class primitive_cache { - int _capacity = 1024; - usage_list_type usage; - cache_map_type cache_map; - void touch(cache_map_type::iterator it, sycl::event e = {}, - bool update_event = false) { - if (it->second->_usage_it != usage.begin()) { - const primitive_cache_key_type &key = it->first; - usage.erase(it->second->_usage_it); - usage.push_front(key); - it->second->_usage_it = usage.begin(); - } - if (update_event) { - it->second->_e = e; - } - } - -public: - std::shared_ptr - get(const primitive_cache_key_type &key) { - auto it = cache_map.find(key); - if (it == cache_map.end()) { - return nullptr; - } - touch(it); - return it->second; - } - void put(const primitive_cache_key_type &key, ::dnnl::primitive *value, - std::unordered_map *args, - std::function destructor, sycl::event e, - sycl::queue *q) { - auto it = cache_map.find(key); - if (it != cache_map.end()) { - touch(it, e, true); - } else { - if (cache_map.size() == _capacity) { - auto v = *(cache_map.find(usage.back())->second); - v._q.submit([=](sycl::handler &cgh) { - cgh.depends_on(v._e); - cgh.host_task([=] { - delete v._args; - v._destructor(v._primitive); - }); - }); - cache_map.erase(usage.back()); - usage.pop_back(); - } - usage.push_front(key); - cache_map[key] = std::make_shared( - value, args, usage.begin(), destructor, e, *q); - } - } -}; -} // namespace detail - -/// A class holding the oneDNN engine. -class engine_ext { - struct output_argument_info { - float _alpha; - float _beta; - int _name; - memory_desc_ext _desc; - void *_data; - output_argument_info(float alpha, float beta, int name, - memory_desc_ext desc, void *data) - : _alpha(alpha), _beta(beta), _name(name), _desc(desc), _data(data) {} - output_argument_info(float alpha, float beta, memory_desc_ext desc, - void *data) - : _alpha(alpha), _beta(beta), _name(0), _desc(desc), _data(data) {} - }; - struct buffer_info { - size_t capacity = 0; - uint8_t *buffer = nullptr; - size_t usage = 0; - sycl::queue q; - sycl::event deps; - size_t primitive_depth = 0; - }; - struct internal_resource { - std::int64_t random_engine_state_size = -1; - buffer_info binfo; - }; - std::shared_ptr<::dnnl::engine> _eng = nullptr; - std::shared_ptr<::dnnl::stream> _s = nullptr; - sycl::queue *_q = nullptr; - unsigned int _engine_id = 0; - static thread_local unsigned int _engine_count; - static thread_local std::map _workspace_map; - static thread_local std::map> - _internal_resource_cache; - static thread_local detail::primitive_cache _primitive_cache; - ::dnnl::memory &get_workspace(void *key) { return _workspace_map[key]; } - void insert_workspace(void *key, ::dnnl::memory workspace) { - _workspace_map[key] = workspace; - } - const ::dnnl::stream &get_stream() const { return *_s; } - const ::dnnl::engine &get_engine() const { return *_eng; } - - void *allocate(const memory_desc_ext &desc, int count = 1); - void *allocate(size_t size); - std::shared_ptr get_internal_resource(sycl::queue *q){ - auto it = _internal_resource_cache.find(_q); - if (it == _internal_resource_cache.end()) { - return _internal_resource_cache[_q] = std::make_shared(); - } - return it->second; - } - void enter_primitive(size_t request_buffer_size = 0) { - auto &info = get_internal_resource(_q)->binfo; - if (info.primitive_depth == 0) { - info.usage = 0; - if (request_buffer_size > info.capacity) { - if (info.buffer && (info.capacity != 0)) { - auto ainfo = info; - ainfo.q.submit([=](sycl::handler &cgh) { - cgh.depends_on(ainfo.deps); - cgh.host_task([=] { sycl::free(ainfo.buffer, ainfo.q); }); - }); - } - size_t new_buffer_capacity = - std::max(request_buffer_size, info.capacity * 2); - info.capacity = new_buffer_capacity; - info.buffer = (uint8_t *)sycl::malloc_device(new_buffer_capacity, *_q); - info.q = *_q; - info.deps = sycl::event(); - } - } - info.primitive_depth++; - } - sycl::event exit_primitive(const sycl::event &e) { - auto &info = get_internal_resource(_q)->binfo; - info.primitive_depth--; - if ((info.primitive_depth == 0) && info.usage) { - info.deps = e; - } - return e; - } - ::dnnl::memory::desc - compress_spatial_dimensions_to_channel(const ::dnnl::memory::desc &desc); - ::dnnl::memory::desc - get_bn_scale_bias_mean_var_desc(const ::dnnl::memory::desc &desc, - batch_normalization_mode mode); - sycl::event batch_normalization_backward_internal( - batch_normalization_mode mode, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta_data, - const memory_desc_ext &diff_src_desc, void *diff_src, float alpha_param, - const memory_desc_ext &diff_scale_bias_desc, void *scale, void *bias, - float beta_param, void *diff_scale, void *diff_bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var); - sycl::event batch_normalization_forward_internal( - bool is_infer, batch_normalization_mode mode, float epsilon, float factor, - float alpha, const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var, - void *running_mean, void *running_var); - ::dnnl::memory::desc - transfer_memory_desc_to_channel_major_format(const ::dnnl::memory::desc &desc); - ::dnnl::memory::desc - bn_reorder_memory_to_channel_major_format( - bool is_input, ::dnnl::memory::desc &desc, void *src, void **cache); - ::dnnl::memory::desc - transfer_memory_desc_to_format_tag_any(const ::dnnl::memory::desc &desc){ - return ::dnnl::memory::desc(desc.get_dims(), desc.get_data_type(), - ::dnnl::memory::format_tag::any); - } - void allocate_and_reorder_memory_to_optimal(::dnnl::memory::desc &from_desc, - void *&from, - ::dnnl::memory::desc &to_desc, - void *&to) { - if (from_desc != to_desc) { - to = allocate(to_desc); - async_reorder(1.f, from_desc, from, 0.f, to_desc, to); - } - } - template - std::pair - create_primitive_args_or_get(args_type &&...args); - template - typename primitive_type::primitive_desc - get_primitive_desc(::dnnl::primitive *p); - template - typename primitive_type::primitive_desc - create_primitive_desc(args_type &&...args); - template - void generate_cache_key(std::string &key_buffer, const T &arg); - template - void generate_cache_key(std::string &key_buffer, const T &first_arg, - const args_type &...args); - void insert_arg(std::unordered_map *args, int name, - const ::dnnl::memory::desc &desc, void *data) { - auto it = args->find(name); - if (it != args->end()) { - it->second.set_data_handle(data); - } else { - args->insert({name, ::dnnl::memory(desc, *_eng, data)}); - } - } - void insert_arg(std::unordered_map *args, int name, - const ::dnnl::memory &mem) { - (*args)[name] = mem; - } - sycl::event execute_rnn_forward_primitive( - rnn_mode mode, ::dnnl::prop_kind kind, ::dnnl::rnn_direction direction, - rnn_bias_mode bias_mode, ::dnnl::memory::data_type dt, - ::dnnl::memory::format_tag tag, int seq_length, int batch_size, int src_c, - int dst_c, int layer_size, int direction_num, int hidden_size, - int gate_num, int projection_size, std::vector &data, - std::vector &offset, int iter_num, size_t *weight_size = nullptr, - size_t *workspace_size = nullptr, size_t *scratchpad_size = nullptr); - - sycl::event rnn_forward_internal( - const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &iter_desc, void *src_iter, void *dst_iter, - const memory_desc_ext &iter_c_desc, void *src_iter_c, void *dst_iter_c, - size_t weight_size, void *weight, size_t workspace_size, void *workspace, - size_t scratchpad_size, void *scratchpad, bool is_get_execution_args, - size_t *weight_size_query, size_t *workspace_size_query, - size_t *scratchpad_size_query); - - sycl::event execute_rnn_backward_primitive( - rnn_mode mode, ::dnnl::rnn_direction direction, rnn_bias_mode bias_mode, - ::dnnl::memory::data_type dt, ::dnnl::memory::format_tag tag, - int seq_length, int batch_size, int src_c, int dst_c, int layer_size, - int direction_num, int hidden_size, int gate_num, int projection_size, - std::vector &data, std::vector &offset, int iter_num); - bool - scale_parameter_preprocess(const std::vector &args); - template - sycl::event - execute_primitive(const std::pair &primitive, - const std::vector &extra_args = {}); - template - sycl::event fill_with_type(sycl::queue *q, void *src, const void *value, - size_t size_with_byte) { - return q->fill(static_cast(src), *static_cast(value), - size_with_byte / sizeof(T)); - } - template struct no_zero_op { - T operator()(T e) { - if (!e) { - return 1; - } - return e; - } - }; - template - void transform_no_zero_with_type(sycl::queue *q, void *src, void *dst, - size_t num) { - std::transform(oneapi::dpl::execution::make_device_policy(*q), - static_cast(src), static_cast(src) + num, - static_cast(dst), no_zero_op()); - } - void transform_no_zero(const memory_desc_ext &desc, void *src, void *dst); - ::dnnl::memory::desc get_group_weight_desc(int group_count, - const memory_desc_ext &weight_desc); - void get_rnn_configuration(const ::dnnl::memory::desc &desc, - rnn_direction direction, rnn_mode mode, - dpct::library_data_t dt, int hidden_size, - ::dnnl::memory::data_type *dnnl_dt, - ::dnnl::memory::format_tag *tag, - int *projection_size, int *output_size, - int *seq_length, int *batch_size, - int *direction_num, int *gate_num); -public: - engine_ext() {} - operator bool() const { - return bool(_eng) && bool(_s) && bool(_q); - } - engine_ext &operator=(std::nullptr_t) { - _eng = nullptr; - _s = nullptr; - _q = nullptr; - return *this; - } - /// Creating oneDNN engine. - void create_engine() { - _q = &dpct::get_current_device().default_queue(); - _eng = std::make_shared<::dnnl::engine>(::dnnl::sycl_interop::make_engine( - dpct::get_current_device(), dpct::get_current_device().get_context())); - _s = std::make_shared<::dnnl::stream>( - ::dnnl::sycl_interop::make_stream(*_eng, *_q)); - _engine_id = _engine_count++; - } - /// Setting the user's SYCL queue for an oneDNN engine. - /// \param [in] q Pointer to the SYCL queue. - void set_queue(sycl::queue *q) { - if (!q) { - throw std::runtime_error("set_queue: pointer must not be nullptr."); - } - if (!_eng) { - throw std::runtime_error("set_queue: current engine is invalid."); - } - if (q->get_context() != ::dnnl::sycl_interop::get_context(*_eng)) { - throw std::runtime_error( - "set_queue: queue is mismatch with current engine context."); - } - _q = q; - _s = std::make_shared<::dnnl::stream>( - ::dnnl::sycl_interop::make_stream(*_eng, *_q)); - } - /// Retrieving the user's SYCL queue set in the oneDNN engine. - /// \returns Pointer to the SYCL queue. - sycl::queue *get_queue() const { return _q; } - /// Setting all elements of a memory to a given value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] valuePtr Pointer to a single value. - void fill(const memory_desc_ext &src_desc, void *src, - const void *valuePtr); - /// Coping the scaled data from a memory to another memory with a different - /// description. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - void reorder(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst); - /// Scaling all the elements of a memory by a given factor. - /// \param [in] alpha Value to scaling factors. - /// \param [in] src_desc Source memory descriptor. - /// \param [out] src Pointer to source data. - void scale(float alpha, const memory_desc_ext &src_desc, void *src); - /// Adding the scaled values of a memory to another memory. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - void sum(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst); - /// Computing a specified activation function value. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - void activation_forward(activation_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst); - /// Computing the gradient of a specified activation function. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - void - activation_backward(activation_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src); - /// Computing a specified pooling function value. - /// \param [in] desc Pooling descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [out] workspace Pointer to workspace generated from forward propagation. - void pooling_forward(pooling_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace = nullptr); - /// Computing the gradient of a specified pooling function. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential - /// source data. - /// \param [in] workspace Pointer to workspace used for backward - /// propagation. - void pooling_backward(pooling_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src, - ::dnnl::memory *workspace = nullptr); - /// Computing a specified softmax function value. - /// \param [in] alg Softmax algorithm. - /// \param [in] mode Softmax mode. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - void softmax_forward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, void *dst); - /// Computing the gradient of a specified softmax function. - /// \param [in] alg Softmax algorithm. - /// \param [in] mode Softmax mode. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - void softmax_backward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, - void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src); - /// Computing a specified local response normalization function value. - /// \param [in] desc Local response normalization descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [out] workspace Pointer to workspace generated from forward - /// propagation. - void lrn_forward(lrn_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace = nullptr); - /// Computing the gradient of a specified local response normalization - /// function. - /// \param [in] desc Local response normalization descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] workspace Pointer to workspace used for backward propagation. - void lrn_backward(lrn_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &diff_src_desc, - void *diff_src, ::dnnl::memory *workspace = nullptr); - /// Setting all elements of a memory to a given value asynchronously. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] valuePtr Pointer to a single value. - /// \returns An event representing the fill operations. - sycl::event async_fill(const memory_desc_ext &src_desc, void *src, - const void *valuePtr); - /// Coping the scaled data from a memory to another memory with a different - /// description asynchronously. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the reorder operations. - sycl::event async_reorder(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst); - /// Scaling all the elements of a memory by a given factor asynchronously. - /// \param [in] alpha Value to scaling factors. - /// \param [in] src_desc Source memory descriptor. - /// \param [out] src Pointer to source data. - /// \returns An event representing the scale operations. - sycl::event async_scale(float alpha, const memory_desc_ext &src_desc, void *src); - /// Adding the scaled values of a memory to another memory asynchronously. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the sum operations. - sycl::event async_sum(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst); - - /// Perform specified binary operation asynchronously. - /// \param [in] op Specified binary operation. - /// \param [in] alpha_0 Value to scaling factors used to scale the src_0 - /// value. - /// \param [in] src_desc_0 Source 0 memory descriptor. - /// \param [in] src_0 Pointer to source 0 data. - /// \param [in] alpha_1 Value to scaling factors used to scale the src_1 - /// value. - /// \param [in] src_desc_1 Source 1 memory descriptor. - /// \param [in] src_1 Pointer to source 1 data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the binary operations. - sycl::event async_binary(binary_op op, float alpha_0, - const memory_desc_ext &src_desc_0, void *src_0, - float alpha_1, const memory_desc_ext &src_desc_1, - void *src_1, float beta, const memory_desc_ext &dst_desc, - void *dst); - - /// Perform specified binary operation asynchronously. - /// \param [in] op Specified reduction operation. - /// \param [in] alpha Value to scaling factors used to scale the data - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the reduction operations. - sycl::event async_reduction(reduction_op op, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst); - /// Computing a specified activation function value asynchronously. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the activation forward operations. - sycl::event async_activation_forward(activation_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst); - /// Computing the gradient of a specified activation function asynchronously. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \returns An event representing the activation backward operations. - sycl::event - async_activation_backward(activation_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src); - /// Computing a specified pooling function value asynchronously. - /// \param [in] desc Pooling descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [out] workspace Pointer to workspace generated from forward propagation. - /// \returns An event representing the pooling forward operations. - sycl::event async_pooling_forward(pooling_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace = nullptr); - /// Computing the gradient of a specified pooling function asynchronously. - /// \param [in] desc Activation descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential - /// source data. - /// \param [in] workspace Pointer to workspace used for backward - /// propagation. - /// \returns An event representing the pooling backward operations. - sycl::event async_pooling_backward(pooling_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src, - ::dnnl::memory *workspace = nullptr); - /// Computing a specified softmax function value asynchronously. - /// \param [in] alg Softmax algorithm. - /// \param [in] mode Softmax mode. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the softmax forward operations. - sycl::event async_softmax_forward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, void *dst); - /// Computing the gradient of a specified softmax function asynchronously. - /// \param [in] alg Softmax algorithm. - /// \param [in] mode Softmax mode. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \returns An event representing the softmax backward operations. - sycl::event async_softmax_backward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, - void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src); - /// Computing a specified local response normalization function value - /// asynchronously. - /// \param [in] desc Local response normalization descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [out] workspace Pointer to workspace generated from forward - /// propagation. - /// \returns An event representing the lrn forward operations. - sycl::event async_lrn_forward(lrn_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace = nullptr); - /// Computing the gradient of a specified local response normalization - /// function asynchronously. - /// \param [in] desc Local response normalization descriptor. - /// \param [in] alpha Value to scaling factors used to scale the computed value. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the differential destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] workspace Pointer to workspace used for backward propagation. - /// \returns An event representing the lrn backward operations. - sycl::event async_lrn_backward(lrn_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &diff_src_desc, - void *diff_src, ::dnnl::memory *workspace = nullptr); - - /// Derives a memory descriptor for the batch normalization scale, bias, mean, - /// variance from the source memory descriptor and batch normalization mode. - /// \param [out] desc Derived memory descriptor. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] mode Batch normalization mode. - static void derive_batch_normalization_memory_desc(memory_desc_ext &desc, - const memory_desc_ext &src_desc, - batch_normalization_mode mode); - - /// Derives a memory descriptor for the batch normalization scale, bias, mean, - /// variance from the source memory descriptor and batch normalization mode. - /// \param [out] scale_bias_desc Derived scale and bias memory descriptor. - /// \param [out] mean_var_desc Derived mean and var memory descriptor. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] mode Batch normalization mode. - static void derive_batch_normalization_memory_desc(memory_desc_ext &scale_bias_desc, - memory_desc_ext &mean_var_desc, - const memory_desc_ext &src_desc, - batch_normalization_mode mode); - - /// Get the size of workspace that needed by batch normalization. The data stored - /// in workspace must be preserved between forward and backward. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] src_desc Source memory descriptor. - /// \returns Size of workspace. - size_t get_batch_normalization_workspace_size( - batch_normalization_ops ops, const memory_desc_ext &src_desc); - - /// Computing a specified batch normalization inference stage function value - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] scale_bias_mean_var_desc Scale, bias, mean, variance memory - /// descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [in] mean Pointer to mean data. - /// \param [in] var Pointer to variance data. - /// \returns An event representing the batch normalization forward operations. - sycl::event async_batch_normalization_forward_inference( - batch_normalization_mode mode, float epsilon, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *mean, void *var); - - /// Computing a specified batch normalization inference stage function value - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] summand_desc Summand memory descriptor. - /// \param [in] summand Pointer to summand data. - /// \param [in] scale_bias_desc Scale, bias memory descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [in] mean_var_desc Mean, variance memory descriptor. - /// \param [in] mean Pointer to mean data. - /// \param [in] var Pointer to variance data. - /// \returns An event representing the batch normalization forward operations. - sycl::event async_batch_normalization_forward_inference( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *mean, void *var); - - /// Computing a specified batch normalization training stage function value - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] factor Factor value used in running mean and variance - /// computation. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] scale_bias_mean_var_desc Scale, bias, mean, variance memory - /// descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [out] running_mean Pointer to running mean data. - /// \param [out] running_var Pointer to running variance data. - /// \param [out] saved_mean Pointer to optional cache to save mean data. - /// \param [out] saved_var Pointer to optional cache to save variance data. - /// \returns An event representing the batch normalization forward operations. - sycl::event async_batch_normalization_forward_training( - batch_normalization_mode mode, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *running_mean, void *running_var, void *saved_mean, void *saved_var); - - /// Computing a specified batch normalization training stage function value - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] factor Factor value used in running mean and variance - /// computation. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] summand_desc Summand memory descriptor. - /// \param [in] summand Pointer to summand data. - /// \param [in] scale_bias_mean_var_desc Scale, bias, mean, variance memory - /// descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [out] running_mean Pointer to running mean data. - /// \param [out] running_var Pointer to running variance data. - /// \param [out] saved_mean Pointer to optional cache to save mean data. - /// \param [out] saved_var Pointer to optional cache to save variance data. - /// \param [in] workspace_size Size of workspace. - /// \param [out] workspace Pointer to workspace generated from forward - /// propagation. - /// \returns An event representing the batch normalization forward operations. - sycl::event async_batch_normalization_forward_training( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *running_mean, void *running_var, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace); - - /// Computing a specified batch normalization training stage function value - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] factor Factor value used in running mean and variance - /// computation. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] summand_desc Summand memory descriptor. - /// \param [in] summand Pointer to summand data. - /// \param [in] scale_bias_desc Scale, bias memory descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [in] mean_var_desc Mean, variance memory descriptor. - /// \param [out] running_mean Pointer to running mean data. - /// \param [out] running_var Pointer to running variance data. - /// \param [out] saved_mean Pointer to optional cache to save mean data. - /// \param [out] saved_var Pointer to optional cache to save variance data. - /// \param [in] workspace_size Size of workspace. - /// \param [out] workspace Pointer to workspace generated from forward - /// propagation. - /// \returns An event representing the batch normalization forward operations. - sycl::event async_batch_normalization_forward_training( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *running_mean, void *running_var, - void *saved_mean, void *saved_var, size_t workspace_size, void *workspace); - - /// Computing the gradient of a specified batch normalization function asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] alpha_data Value to scaling factors used to scale the computed - /// data value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta_data Value to scaling factors used to scale the prior value - /// in the data memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] alpha_param Value to scaling factors used to scale the computed - /// parameter value. - /// \param [in] diff_scale_bias_mean_var_desc Differential scale, bias, mean, - /// variance memory descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] beta_param Value to scaling factors used to scale the prior value - /// in the parameter memory. - /// \param [in] diff_scale Pointer to differential scale data. - /// \param [in] diff_bias Pointer to differential bias data. - /// \param [in] saved_mean Pointer to optional cache saved mean data in forward. - /// \param [in] saved_var Pointer to optional cache saved variance data in forward. - /// \returns An event representing the batch normalization backward operations. - sycl::event async_batch_normalization_backward( - batch_normalization_mode mode, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta_data, - const memory_desc_ext &diff_src_desc, void *diff_src, float alpha_param, - const memory_desc_ext &diff_scale_bias_mean_var_desc, void *scale, - float beta_param, void *diff_scale, void *diff_bias, void *saved_mean, - void *saved_var); - - /// Computing the gradient of a specified batch normalization function - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] alpha_data Value to scaling factors used to scale the computed - /// data value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta_data Value to scaling factors used to scale the prior value - /// in the data memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] diff_summand_desc Differential summand memory descriptor. - /// \param [out] diff_summand Pointer to differential summand data. - /// \param [in] alpha_param Value to scaling factors used to scale the computed - /// parameter value. - /// \param [in] diff_scale_bias_mean_var_desc Differential scale, bias, mean, - /// variance memory descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [in] beta_param Value to scaling factors used to scale the prior value - /// in the parameter memory. - /// \param [out] diff_scale Pointer to differential scale data. - /// \param [out] diff_bias Pointer to differential bias data. - /// \param [in] saved_mean Pointer to optional cache saved mean data in forward. - /// \param [in] saved_var Pointer to optional cache saved variance data in forward. - /// \param [in] workspace_size Size of workspace. - /// \param [in] workspace Pointer to workspace used for backward propagation. - /// \returns An event representing the batch normalization backward operations. - sycl::event async_batch_normalization_backward( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta_data, - const memory_desc_ext &diff_src_desc, void *diff_src, - const memory_desc_ext &diff_summand_desc, void *diff_summand, - float alpha_param, const memory_desc_ext &diff_scale_bias_mean_var_desc, - void *scale, void *bias, float beta_param, void *diff_scale, - void *diff_bias, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace); - - /// Computing the gradient of a specified batch normalization function - /// asynchronously. - /// \param [in] mode Batch normalization mode. - /// \param [in] ops Batch normalization operation mode. This mode can set to - /// perform only batch normalization, or batch normalization followed by - /// activation, or batch normalization followed by element-wise addition and - /// activation. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] epsilon Epsilon value used in computation. - /// \param [in] alpha_data Value to scaling factors used to scale the computed - /// data value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta_data Value to scaling factors used to scale the prior value - /// in the data memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] diff_summand_desc Differential summand memory descriptor. - /// \param [out] diff_summand Pointer to differential summand data. - /// \param [in] alpha_param Value to scaling factors used to scale the computed - /// parameter value. - /// \param [in] diff_scale_bias_desc Differential scale, bias memory descriptor. - /// \param [in] scale Pointer to scale data. - /// \param [in] bias Pointer to bias data. - /// \param [in] beta_param Value to scaling factors used to scale the prior value - /// in the parameter memory. - /// \param [out] diff_scale Pointer to differential scale data. - /// \param [out] diff_bias Pointer to differential bias data. - /// \param [in] mean_var_desc Differential mean, variance memory descriptor. - /// \param [in] saved_mean Pointer to optional cache saved mean data in forward. - /// \param [in] saved_var Pointer to optional cache saved variance data in forward. - /// \param [in] workspace_size Size of workspace. - /// \param [in] workspace Pointer to workspace used for backward propagation. - /// \returns An event representing the batch normalization backward operations. - sycl::event async_batch_normalization_backward( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, void *diff_dst, - float beta_data, const memory_desc_ext &diff_src_desc, void *diff_src, - const memory_desc_ext &diff_summand_desc, void *diff_summand, - float alpha_param, const memory_desc_ext &diff_scale_bias_desc, void *scale, - void *bias, float beta_param, void *diff_scale, void *diff_bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace); - - /// Computing a specified convolution function value asynchronously. - /// \param [in] desc Convolution descriptor. - /// \param [in] alg Convolution algorithm. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] weight_desc Weight memory descriptor. - /// \param [in] weight Pointer to weight data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the convolution forward operations. - sycl::event async_convolution_forward(convolution_desc &desc, ::dnnl::algorithm alg, - float alpha, const memory_desc_ext &src_desc, - void *src, const memory_desc_ext &weight_desc, - void *weight, float beta, - const memory_desc_ext &dst_desc, void *dst); - - /// Computing a specified convolution function value asynchronously. - /// \param [in] desc Convolution descriptor. - /// \param [in] alg Convolution algorithm. - /// \param [in] adesc Activation operation descriptor. - /// \param [in] alpha_0 Value to scaling factors used to scale the data - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] weight_desc Weight memory descriptor. - /// \param [in] weight Pointer to weight data. - /// \param [in] alpha_1 Value to scaling factors used to scale the summand - /// value. - /// \param [in] summand_desc Summand memory descriptor. - /// \param [in] summand Pointer to summand data. - /// \param [in] bias_desc Bias memory descriptor. - /// \param [in] bias Pointer to bias data. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \returns An event representing the convolution forward operations. - sycl::event async_convolution_forward( - convolution_desc &desc, ::dnnl::algorithm alg, activation_desc &adesc, - float alpha_0, const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &weight_desc, void *weight, float alpha_1, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &bias_desc, void *bias, - const memory_desc_ext &dst_desc, void *dst); - - /// Computing the data gradient of a specified convolution function asynchronously. - /// \param [in] desc Convolution descriptor. - /// \param [in] alg Convolution algorithm. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] weight_desc Weight memory descriptor. - /// \param [in] weight Pointer to weight data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \returns An event representing the convolution backward data operations. - sycl::event async_convolution_backward_data( - convolution_desc &desc, ::dnnl::algorithm alg, float alpha, - const memory_desc_ext &weight_desc, void *weight, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src); - - /// Computing the weight gradient of a specified convolution function - /// asynchronously. - /// \param [in] desc Convolution descriptor. - /// \param [in] alg Convolution algorithm. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] diff_weight_desc Differential weight memory descriptor. - /// \param [out] diff_weight Pointer to differential weight data. - /// \returns An event representing the convolution backward weight operations. - sycl::event async_convolution_backward_weight( - convolution_desc &desc, ::dnnl::algorithm alg, float alpha, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta, - const memory_desc_ext &diff_weight_desc, void *diff_weight); - - /// Computing the bias gradient of a specified convolution function - /// asynchronously. - /// \param [in] alpha Value to scaling factors used to scale the computed - /// value. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] beta Value to scaling factors used to scale the prior value - /// in the destination memory. - /// \param [in] diff_bias_desc Differential bias memory descriptor. - /// \param [out] diff_bias Pointer to differential bias data. - /// \returns An event representing the convolution backward bias operations. - sycl::event async_convolution_backward_bias(float alpha, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, float beta, - const memory_desc_ext &diff_bias_desc, - void *diff_bias); - - /// Getting the required weight space size for specified rnn operation. - /// \param [in] desc RNN descriptor. - /// \param [out] weight_space_size Size of required weight space. - void rnn_get_weight_space_size(const rnn_desc &desc, - size_t *weight_space_size); - - /// Getting the required scratchpad size and workspace size for specified rnn operation. - /// \param [in] desc RNN descriptor. - /// \param [in] kind Propagation kind. - /// \param [in] src_desc Source memory descriptor. - /// \param [out] scratchpad_size Size of required scratchpad. - /// \param [out] workspace_size Size of required workspace. - void rnn_get_scratchpad_workspace_size(const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, - size_t *scratchpad_size, size_t *workspace_size); - - /// Computing a specified rnn function value asynchronously. - /// \param [in] desc RNN descriptor. - /// \param [in] kind Propagation kind. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] iter_desc Recurrent hidden state data memory descriptor. - /// \param [in] src_iter Pointer to input recurrent hidden state data. - /// \param [in] dst_iter Pointer to output recurrent hidden state data. - /// \param [in] iter_c_desc Recurrent cell state data memory descriptor. - /// \param [in] src_c_iter Pointer to input recurrent cell state data. - /// \param [in] dst_c_iter Pointer to output recurrent cell state data. - /// \param [in] weight_size Size of weight memory. - /// \param [in] weight Pointer to weight data. - /// \param [in] scratchpad_size Size of scratchpad memory. - /// \param [in] scratchpad Pointer to scratchpad data. - /// \param [in] workspace_size Size of workspace memory. - /// \param [in] workspace Pointer to workspace data. - /// \returns An event representing the status of rnn forward operations. - sycl::event async_rnn_forward(const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &iter_desc, void *src_iter, - void *dst_iter, - const memory_desc_ext &iter_c_desc, - void *src_iter_c, void *dst_iter_c, - size_t weight_size, void *weight, - size_t scratchpad_size, void *scratchpad, - size_t workspace_size, void *workspace); - - /// Computing the data and weight gradient of a specified rnn function - /// asynchronously. - /// \param [in] desc RNN descriptor. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [in] dst Pointer to destination data. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] iter_desc Recurrent hidden state data memory descriptor. - /// \param [in] src_iter Pointer to input recurrent hidden state data. - /// \param [in] diff_dst_iter Pointer to differential output recurrent hidden state data. - /// \param [out] diff_src_iter Pointer to differential input recurrent hidden state data. - /// \param [in] iter_c_desc Recurrent cell state data memory descriptor. - /// \param [in] src_c_iter Pointer to input recurrent cell state data. - /// \param [in] diff_dst_c_iter Pointer to differential output recurrent cell state data. - /// \param [out] diff_src_c_iter Pointer to differential input recurrent cell state data. - /// \param [in] weight_size Size of weight memory. - /// \param [in] weight Pointer to weight data. - /// \param [out] diff_weight Pointer to differential weight data. - /// \param [in] scratchpad_size Size of scratchpad memory. - /// \param [in] scratchpad Pointer to scratchpad data. - /// \param [in] workspace_size Size of workspace memory. - /// \param [in] workspace Pointer to workspace data. - /// \returns An event representing the status of rnn backward operations. - sycl::event async_rnn_backward( - const rnn_desc &desc, const memory_desc_ext &dst_desc, void *dst, - void *diff_dst, const memory_desc_ext &src_desc, void *src, - void *diff_src, const memory_desc_ext &iter_desc, void *src_iter, - void *diff_dst_iter, void *diff_src_iter, - const memory_desc_ext &iter_c_desc, void *src_iter_c, - void *diff_dst_iter_c, void *diff_src_iter_c, size_t weight_size, - void *weight, void *diff_weight, size_t scratchpad_size, void *scratchpad, - size_t workspace_size, void *workspace); - - /// Getting the required state size for specified dropout operation. - /// \param [in] src_desc Source memory descriptor. - /// \returns Required size of state. - size_t get_dropout_state_size(); - - /// Getting the required workspace size for dropout operation. - /// \param [in] src_desc Source memory descriptor. - /// \returns Required size of workspace. - static size_t get_dropout_workspace_size(const memory_desc_ext &src_desc); - - /// Computing a specified dropout function value asynchronously. - /// \param [in] desc Dropout descriptor. - /// \param [in] src_desc Source memory descriptor. - /// \param [in] src Pointer to source data. - /// \param [in] dst_desc Destination memory descriptor. - /// \param [out] dst Pointer to destination data. - /// \param [in] workspace Pointer to workspace data. - /// \param [in] workspace_size Size of workspace memory. - /// \returns An event representing the dropout forward operations. - sycl::event async_dropout_forward(dropout_desc &desc, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &dst_desc, void *dst, - void *workspace, size_t workspace_size); - - /// Computing the gradient of a specified dropout function asynchronously. - /// \param [in] desc Dropout descriptor. - /// \param [in] diff_dst_desc Differential destination memory descriptor. - /// \param [in] diff_dst Pointer to differential destination data. - /// \param [in] diff_src_desc Differential source memory descriptor. - /// \param [out] diff_src Pointer to differential source data. - /// \param [in] workspace Pointer to workspace data. - /// \param [in] workspace_size Size of workspace memory. - /// \returns An event representing the dropout backward operations. - sycl::event async_dropout_backward(dropout_desc &desc, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, - const memory_desc_ext &diff_src_desc, - void *diff_src, void *workspace, - size_t workspace_size); -}; - -inline thread_local unsigned int engine_ext::_engine_count; -inline thread_local detail::primitive_cache engine_ext::_primitive_cache; -inline thread_local std::map engine_ext::_workspace_map; -inline thread_local std::map> - engine_ext::_internal_resource_cache; - -inline -void dropout_desc::restore(engine_ext &engine, float p, void *state, - size_t state_size, unsigned long long seed) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " - "Interfaces Project does not support this API."); -#else - if (state) { - std::int64_t required_state_size = engine.get_dropout_state_size(); - if (state_size < required_state_size) { - throw std::runtime_error("restore: state_size less than required state size."); - } - sycl::queue *q = engine.get_queue(); - _imp->_p = p; - _imp->_seed = seed; - _imp->_state = state; - _imp->_host_state = std::vector(required_state_size); - q->memcpy(_imp->_host_state.data(), _imp->_state, required_state_size).wait(); - _imp->_rng_engine = - oneapi::mkl::rng::load_state( - *q, _imp->_host_state.data()); - } -#endif -} - -inline -void dropout_desc::set(engine_ext &engine, float p, void *state, - size_t state_size, unsigned long long seed) { -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " - "Interfaces Project does not support this API."); -#else - _imp->_p = p; - if (state) { - std::int64_t required_state_size = engine.get_dropout_state_size(); - if (state_size < required_state_size) { - throw std::runtime_error("set: no sufficient memory to save states."); - } - sycl::queue *q = engine.get_queue(); - _imp->_seed = seed; - _imp->_state = state; - _imp->_host_state = std::vector(required_state_size); - _imp->_rng_engine = rng_engine_t(*q, seed); - oneapi::mkl::rng::save_state(_imp->_rng_engine, _imp->_host_state.data()); - q->memcpy(_imp->_state, _imp->_host_state.data(), required_state_size).wait(); - } -#endif -} - -inline -::dnnl::memory::data_type -memory_desc_ext::to_dnnl_data_type(dpct::library_data_t dt) { - using dnnl_dt = ::dnnl::memory::data_type; - switch (dt) { - case dpct::library_data_t::real_half: - return dnnl_dt::f16; - case dpct::library_data_t::real_bfloat16: - return dnnl_dt::bf16; - case dpct::library_data_t::real_float: - return dnnl_dt::f32; - case dpct::library_data_t::real_double: - return dnnl_dt::f64; - case dpct::library_data_t::real_int32: - return dnnl_dt::s32; - case dpct::library_data_t::real_int8: - return dnnl_dt::s8; - case dpct::library_data_t::real_uint8: - return dnnl_dt::u8; - case dpct::library_data_t::real_int8_4: - return dnnl_dt::s8; - case dpct::library_data_t::real_int8_32: - return dnnl_dt::s8; - case dpct::library_data_t::real_uint8_4: - return dnnl_dt::u8; - default: - throw std::runtime_error("to_dnnl_data_type: unsupported data type."); - } -} - -inline -dpct::library_data_t -memory_desc_ext::to_dpct_library_data_t(::dnnl::memory::data_type dt, - unsigned block_size) { - using dpct_dt = dpct::library_data_t; - using dnnl_dt = ::dnnl::memory::data_type; - switch (dt) { - case dnnl_dt::f16: - return dpct_dt::real_half; - case dnnl_dt::bf16: - return dpct_dt::real_bfloat16; - case dnnl_dt::f32: - return dpct_dt::real_float; - case dnnl_dt::f64: - return dpct_dt::real_double; - case dnnl_dt::s32: - return dpct_dt::real_int32; - case dnnl_dt::s8: - if (block_size == 4) { - return dpct_dt::real_int8_4; - } else if (block_size == 32) { - return dpct_dt::real_int8_32; - } else { - return dpct_dt::real_int8; - } - case dnnl_dt::u8: - if (block_size == 4) { - return dpct_dt::real_uint8_4; - } else { - return dpct_dt::real_uint8; - } - default: - throw std::runtime_error("to_dpct_library_data_t: unsupported data type " - "dnnl::memory::data_type::undef."); - } -} - -inline -::dnnl::memory::format_tag -memory_desc_ext::to_dnnl_format_tag(dpct::library_data_t dt, - memory_format_tag tag) { - using dpct_dt = dpct::library_data_t; - using dpct_tag = memory_format_tag; - using dnnl_tag = ::dnnl::memory::format_tag; - switch (tag) { - case dpct_tag::nchw: - return dnnl_tag::nchw; - case dpct_tag::nhwc: - return dnnl_tag::nhwc; - default: - if (dt == dpct_dt::real_int8_32) { - return dnnl_tag::nChw32c; - } else { - return dnnl_tag::nChw4c; - } - } -} - -inline -void memory_desc_ext::set(memory_format_tag tag, dpct::library_data_t dt, int n, - int c, int h, int w) { - _desc = ::dnnl::memory::desc({n, c, h, w}, to_dnnl_data_type(dt), - to_dnnl_format_tag(dt, tag)); -} - -inline -void memory_desc_ext::set(dpct::library_data_t dt, int n, int c, int h, int w, - int n_stride, int c_stride, int h_stride, - int w_stride) { - _desc = ::dnnl::memory::desc({n, c, h, w}, to_dnnl_data_type(dt), - {n_stride, c_stride, h_stride, w_stride}); -} - -inline -void memory_desc_ext::set(dpct::library_data_t dt, int ndims, const int dims[], - const int strides[]) { - _desc = ::dnnl::memory::desc({dims, dims + ndims}, to_dnnl_data_type(dt), - {strides, strides + ndims}); -} - -inline -void memory_desc_ext::set(memory_format_tag tag, dpct::library_data_t dt, - int ndims, const int dims[]) { - _desc = ::dnnl::memory::desc({dims, dims + ndims}, to_dnnl_data_type(dt), - to_dnnl_format_tag(dt, tag)); -} - -inline -void memory_desc_ext::set(rnn_memory_format_tag tag, dpct::library_data_t dt, - int t, int n, int c) { - if (tag == rnn_memory_format_tag::tnc) { - _desc = ::dnnl::memory::desc({t, n, c}, to_dnnl_data_type(dt), - ::dnnl::memory::format_tag::tnc); - } else if(tag == rnn_memory_format_tag::ntc) { - _desc = ::dnnl::memory::desc({t, n, c}, to_dnnl_data_type(dt), - ::dnnl::memory::format_tag::ntc); - } else { - throw std::runtime_error("set: unsupported memory format tag."); - } -} - -inline -void memory_desc_ext::get(dpct::library_data_t *dt, int *n, int *c, int *h, - int *w, int *n_stride, int *c_stride, int *h_stride, - int *w_stride) const { - unsigned block_size = 1; - auto dims = _desc.get_dims(); - auto inner_blks = _desc.get_inner_blks(); - auto strides = _desc.get_strides(); - if (!inner_blks.empty()) { - block_size = inner_blks[0]; - } - - *dt = to_dpct_library_data_t(_desc.get_data_type(), block_size); - *n = dims[0]; - *c = dims[1]; - *h = dims[2]; - *w = dims[3]; - *n_stride = strides[0] / block_size; - *c_stride = strides[1] / block_size; - *h_stride = strides[2] / block_size; - *w_stride = strides[3] / block_size; -} - -inline -void memory_desc_ext::get(dpct::library_data_t *dt, memory_format_tag *tag, - int *n, int *c, int *h, int *w) const { - unsigned block_size = 1; - *tag = memory_format_tag::nchw; - auto dims = _desc.get_dims(); - auto strides = _desc.get_strides(); - auto inner_blks = _desc.get_inner_blks(); - if (!inner_blks.empty()) { - block_size = inner_blks[0]; - *tag = memory_format_tag::nchw_blocked; - } - if (strides[1] == 1 && dims[1] != 1) { - *tag = memory_format_tag::nhwc; - } - *dt = to_dpct_library_data_t(_desc.get_data_type(), block_size); - *n = dims[0]; - *c = dims[1]; - *h = dims[2]; - *w = dims[3]; -} - -inline -void memory_desc_ext::get(dpct::library_data_t *dt, rnn_memory_format_tag *tag, - int *t, int *n, int *c) const { - auto dims = _desc.get_dims(); - auto strides = _desc.get_strides(); - - if (strides[0] >= strides[1]) { - *tag = rnn_memory_format_tag::tnc; - } else { - *tag = rnn_memory_format_tag::ntc; - } - - *dt = to_dpct_library_data_t(_desc.get_data_type(), 1); - *t = dims[0]; - *n = dims[1]; - *c = dims[2]; -} - -inline -void memory_desc_ext::get(int requested_ndims, dpct::library_data_t *dt, - int *ndims, int dims[], int strides[]) const { - unsigned block_size = 1; - auto inner_blks = _desc.get_inner_blks(); - auto adims = _desc.get_dims(); - auto astrides = _desc.get_strides(); - if (!inner_blks.empty()) { - block_size = inner_blks[0]; - } - *dt = to_dpct_library_data_t(_desc.get_data_type(), block_size); - *ndims = _desc.get_ndims(); - for (int index = 0; index < requested_ndims; index++) { - dims[index] = adims[index]; - strides[index] = - astrides[index] / block_size; - } -} - -inline -void memory_desc_ext::get(int requested_ndims, dpct::library_data_t *dt, - memory_format_tag *tag, int *ndims, - int dims[]) const { - unsigned block_size = 1; - *tag = memory_format_tag::nchw; - auto inner_blks = _desc.get_inner_blks(); - auto adims = _desc.get_dims(); - auto astrides = _desc.get_strides(); - if (!inner_blks.empty()) { - block_size = inner_blks[0]; - *tag = memory_format_tag::nchw_blocked; - } - if (astrides[1] == 1 && - adims[1] != 1) { - *tag = memory_format_tag::nhwc; - } - *dt = to_dpct_library_data_t(_desc.get_data_type(), block_size); - *ndims = _desc.get_ndims(); - for (int index = 0; index < requested_ndims; index++) { - dims[index] = adims[index]; - } -} - -inline -void engine_ext::get_rnn_configuration(const ::dnnl::memory::desc &desc, - rnn_direction direction, rnn_mode mode, - dpct::library_data_t dt, int hidden_size, - ::dnnl::memory::data_type *dnnl_dt, - ::dnnl::memory::format_tag *tag, - int *projection_size, int *output_size, - int *seq_length, int *batch_size, - int *direction_num, int *gate_num) { - if (!desc.is_zero()) { - auto dims = desc.get_dims(); - auto strides = desc.get_strides(); - if (strides[0] >= strides[1]) { - *tag = ::dnnl::memory::format_tag::tnc; - *seq_length = dims[0]; - *batch_size = dims[1]; - } else { - *tag = ::dnnl::memory::format_tag::ntc; - *seq_length = dims[1]; - *batch_size = dims[0]; - } - } - if (direction == rnn_direction::bidirectional) { - *direction_num = 2; - } else { - *direction_num = 1; - } - if (mode == rnn_mode::lstm) { - *gate_num = 4; - } else if (mode == rnn_mode::gru) { - *gate_num = 3; - } else { - *gate_num = 1; - } - if (*projection_size != hidden_size) { - *output_size = *projection_size; - } else { - *projection_size = 0; - *output_size = hidden_size; - } - *dnnl_dt = memory_desc_ext::to_dnnl_data_type(dt); -} - -inline -void *engine_ext::allocate(const memory_desc_ext &data_desc, int count) { - return allocate(data_desc.get_size() * count); -} - -inline -void *engine_ext::allocate(size_t size) { - auto &Info = get_internal_resource(_q)->binfo; - uint8_t *result = Info.buffer + Info.usage; - Info.usage += size; - return result; -} - -inline -void engine_ext::transform_no_zero(const memory_desc_ext &desc, void *src, void *dst) { - ::dnnl::memory::data_type dt = desc.get_desc().get_data_type(); - size_t element_num = desc.get_element_num(); - switch (dt) { - case ::dnnl::memory::data_type::f32: - transform_no_zero_with_type(_q, src, dst, element_num); - break; - case ::dnnl::memory::data_type::f16: - transform_no_zero_with_type(_q, src, dst, element_num); - break; - case ::dnnl::memory::data_type::s32: - transform_no_zero_with_type(_q, src, dst, element_num); - break; - case ::dnnl::memory::data_type::s8: - transform_no_zero_with_type(_q, src, dst, element_num); - break; - case ::dnnl::memory::data_type::u8: - transform_no_zero_with_type(_q, src, dst, element_num); - break; - default: - throw std::runtime_error("transform_no_zero: unsupported data type."); - } -} - -inline -::dnnl::memory::desc -engine_ext::get_group_weight_desc(int group_count, - const memory_desc_ext &weight_desc) { - if (group_count == 1) { - return weight_desc.get_desc(); - } - auto help_weight_desc = weight_desc.get_desc(); - int ndims = help_weight_desc.get_ndims(); - if (!help_weight_desc.get_inner_blks().empty()) { - throw std::runtime_error("get_group_weight_desc: group convolution with " - "blocked weight memory unimplemented."); - } - std::vector new_size; - auto old_size = weight_desc.get_dims(); - new_size.push_back(group_count); - new_size.push_back(old_size[0] / group_count); - for (int index = 1; index < old_size.size(); index++) { - new_size.push_back(old_size[index]); - } - std::vector strides = help_weight_desc.get_strides(); - ::dnnl::memory::format_tag tag; - bool is_nhwc = (strides[1] == 1 && old_size[1] != 1); - - if (ndims == 4) { - if (is_nhwc) { - tag = ::dnnl::memory::format_tag::gohwi; - } else { - tag = ::dnnl::memory::format_tag::goihw; - } - } else if (ndims == 5) { - if (is_nhwc) { - tag = ::dnnl::memory::format_tag::godhwi; - } else { - tag = ::dnnl::memory::format_tag::goidhw; - } - } - - help_weight_desc = - ::dnnl::memory::desc(new_size, weight_desc.get_desc().get_data_type(), tag); - return help_weight_desc; -} - -inline -::dnnl::memory::desc engine_ext::compress_spatial_dimensions_to_channel( - const ::dnnl::memory::desc &desc) { - int ndims = desc.get_ndims(); - auto dims = desc.get_dims(); - auto inner_blks = desc.get_inner_blks(); - assert(ndims >= 4 && "ndims is at least 4."); - std::vector compressed_dims(ndims); - compressed_dims[0] = dims[0]; - compressed_dims[1] = dims[1]; - for (int index = 2; index < ndims; index++) { - compressed_dims[1] = compressed_dims[1] * dims[index]; - compressed_dims[index] = 1; - } - if (!inner_blks.empty() && inner_blks[0] == 4) { - return ::dnnl::memory::desc(compressed_dims, desc.get_data_type(), - ::dnnl::memory::format_tag::nChw4c); - } else if (!inner_blks.empty() && inner_blks[0] == 32) { - return ::dnnl::memory::desc(compressed_dims, desc.get_data_type(), - ::dnnl::memory::format_tag::nChw32c); - } - std::vector strides(ndims, 1); - strides[0] = compressed_dims[1]; - - return ::dnnl::memory::desc(compressed_dims, desc.get_data_type(), strides); -} - -inline -::dnnl::memory::desc -engine_ext::get_bn_scale_bias_mean_var_desc(const ::dnnl::memory::desc &desc, - batch_normalization_mode mode) { - int ndims = desc.get_ndims(); - auto dims = desc.get_dims(); - assert(ndims >= 4 && "ndims is at least 4."); - int channel_num = 1; - if (mode == batch_normalization_mode::spatial) { - channel_num = dims[1]; - } else { - for (int index = 1; index < ndims; index++) { - channel_num = channel_num * dims[index]; - } - } - return ::dnnl::memory::desc({channel_num}, desc.get_data_type(), - ::dnnl::memory::format_tag::a); -} - -inline -::dnnl::memory::desc engine_ext::transfer_memory_desc_to_channel_major_format( - const ::dnnl::memory::desc &desc) { - if (!desc.get_inner_blks().empty()) { - return desc; - } - int ndims = desc.get_ndims(); - auto dims = desc.get_dims(); - if (ndims == 4) { - return ::dnnl::memory::desc(dims, desc.get_data_type(), - ::dnnl::memory::format_tag::nchw); - } - return ::dnnl::memory::desc(dims, desc.get_data_type(), - ::dnnl::memory::format_tag::ncdhw); -} - -/// If the alpha = 0 and beta = 1, then the destination (dst = alpha * out + -/// beta * prior_dst) have no change. In this case this function returns true -/// means the operation can exit directly. -inline -bool engine_ext::scale_parameter_preprocess( - const std::vector &args) { - bool direct_exit = true; - for (auto &arg : args) { - if (arg._alpha == 0.f) { - if (arg._beta != 1.f) { - async_scale(arg._beta, arg._desc, arg._data); - } - } else { - direct_exit = false; - } - } - return direct_exit; -} - -inline -void engine_ext::derive_batch_normalization_memory_desc( - memory_desc_ext &scale_bias_desc, memory_desc_ext &mean_var_desc, - const memory_desc_ext &src_desc, batch_normalization_mode mode) { - derive_batch_normalization_memory_desc(scale_bias_desc, src_desc, mode); - derive_batch_normalization_memory_desc(mean_var_desc, src_desc, mode); -} - -inline -void engine_ext::derive_batch_normalization_memory_desc( - memory_desc_ext &desc, const memory_desc_ext &src_desc, - batch_normalization_mode mode) { - int src_ndims = src_desc.get_desc().get_ndims(); - auto inner_blks = src_desc.get_desc().get_inner_blks(); - if (src_desc.get_desc().get_ndims() != 4 || - src_desc.get_desc().get_ndims() != 5) { - throw std::runtime_error("derive_batch_normalization_memory_desc: only 4d " - "and 5d memory descriptor supported."); - } - std::vector dims = src_desc.get_dims(); - dims[0] = 1; - if (mode == batch_normalization_mode::spatial) { - dims[2] = 1; - dims[3] = 1; - if (src_ndims == 5) { - dims[4] = 1; - } - } - auto data_type = src_desc.get_desc().get_data_type(); - if (data_type == ::dnnl::memory::data_type::f16) { - data_type = ::dnnl::memory::data_type::f32; - } - if (!inner_blks.empty() && inner_blks[0] == 4) { - desc.set_desc(::dnnl::memory::desc(dims, data_type, - ::dnnl::memory::format_tag::nChw4c)); - } else if (!inner_blks.empty() && inner_blks[0] == 32) { - desc.set_desc(::dnnl::memory::desc(dims, data_type, - ::dnnl::memory::format_tag::nChw32c)); - } else { - if (src_ndims == 4) { - desc.set_desc(::dnnl::memory::desc(dims, data_type, - ::dnnl::memory::format_tag::nchw)); - } else { - desc.set_desc(::dnnl::memory::desc(dims, data_type, - ::dnnl::memory::format_tag::ncdhw)); - } - } -} - -template -sycl::event engine_ext::execute_primitive( - const std::pair - &primitive, - const std::vector &output_args) { - std::vector caches; - int output_arg_num = output_args.size(); - for (int i = 0; i < output_arg_num; i++) { - if (output_args[i]._beta != 0.f) { - auto cache = allocate(output_args[i]._desc); - caches.push_back(cache); - (*primitive.second.args)[output_args[i]._name].set_data_handle(cache); - } - } - - auto e = ::dnnl::sycl_interop::execute( - *(static_cast(primitive.second.primitive)), *_s, - *primitive.second.args); - _primitive_cache.put( - primitive.first, primitive.second.primitive, primitive.second.args, - [](::dnnl::primitive *p) { delete static_cast(p); }, e, - _q); - int cache_index = 0; - for (int i = 0; i < output_arg_num; i++) { - if (output_args[i]._beta != 0.f) { - e = async_sum(output_args[i]._alpha, output_args[i]._desc, - caches[cache_index++], output_args[i]._beta, - output_args[i]._desc, output_args[i]._data); - } else { - if (output_args[i]._alpha != 1.f) { - e = async_scale(output_args[i]._alpha, output_args[i]._desc, - output_args[i]._data); - } - } - } - return e; -} - -inline -::dnnl::memory::desc engine_ext::bn_reorder_memory_to_channel_major_format( - bool is_input, ::dnnl::memory::desc &desc, void *src, void **cache) { - ::dnnl::memory::desc result; - result = transfer_memory_desc_to_channel_major_format(desc); - if ((result != desc) || !src) { - *cache = allocate(desc); - if (is_input && src) { - async_reorder(1.f, desc, src, 0.f, result, *cache); - } - } - return result; -} - -inline -sycl::event engine_ext::batch_normalization_backward_internal( - batch_normalization_mode mode, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta_data, - const memory_desc_ext &diff_src_desc, void *diff_src, float alpha_param, - const memory_desc_ext &diff_scale_bias_desc, void *scale, void *bias, - float beta_param, void *diff_scale, void *diff_bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var) { - if (scale_parameter_preprocess( - {{alpha_data, beta_data, diff_src_desc, diff_src}, - {alpha_param, beta_param, diff_scale_bias_desc, diff_scale}, - {alpha_param, beta_param, diff_scale_bias_desc, diff_bias}})) { - return sycl::event(); - } - - void *reordered_src = nullptr, *reordered_diff_dst = nullptr, - *reordered_diff_src = nullptr, *reordered_scale = nullptr, - *reordered_bias = nullptr, *reordered_diff_scale = nullptr, - *reordered_diff_bias = nullptr, *reordered_saved_mean = nullptr, - *reordered_saved_var = nullptr; - - ::dnnl::memory::desc help_src_desc = src_desc.get_desc(); - ::dnnl::memory::desc help_diff_dst_desc = diff_dst_desc.get_desc(); - ::dnnl::memory::desc help_diff_src_desc = diff_src_desc.get_desc(); - ::dnnl::memory::desc help_diff_scale_bias_desc = - diff_scale_bias_desc.get_desc(); - ::dnnl::memory::desc help_mean_var_desc = mean_var_desc.get_desc(); - ::dnnl::memory::desc actual_diff_src_desc = help_diff_src_desc; - ::dnnl::memory::desc actual_diff_scale_bias_desc = help_diff_scale_bias_desc; - enter_primitive( - help_diff_scale_bias_desc.get_size() * 14 + help_src_desc.get_size() * 2 + - help_diff_dst_desc.get_size() * 7 + help_diff_src_desc.get_size() * 5 + - help_mean_var_desc.get_size() * 13); - if (mode == batch_normalization_mode::per_activation) { - help_src_desc = bn_reorder_memory_to_channel_major_format(true, help_src_desc, src, - &reordered_src); - help_diff_dst_desc = bn_reorder_memory_to_channel_major_format( - true, help_diff_dst_desc, diff_dst, &reordered_diff_dst); - help_diff_src_desc = bn_reorder_memory_to_channel_major_format( - false, help_diff_src_desc, diff_src, &reordered_diff_src); - actual_diff_src_desc = help_diff_src_desc; - help_diff_scale_bias_desc = bn_reorder_memory_to_channel_major_format( - true, help_diff_scale_bias_desc, scale, &reordered_scale); - actual_diff_scale_bias_desc = help_diff_scale_bias_desc; - if (bias) { - bn_reorder_memory_to_channel_major_format(true, help_diff_scale_bias_desc, bias, - &reordered_bias); - } - bn_reorder_memory_to_channel_major_format(false, help_diff_scale_bias_desc, - diff_scale, &reordered_diff_scale); - bn_reorder_memory_to_channel_major_format(false, help_diff_scale_bias_desc, - diff_bias, &reordered_diff_bias); - - help_mean_var_desc = bn_reorder_memory_to_channel_major_format( - true, help_mean_var_desc, saved_mean, &reordered_saved_mean); - bn_reorder_memory_to_channel_major_format(true, help_mean_var_desc, saved_var, - &reordered_saved_var); - help_src_desc = compress_spatial_dimensions_to_channel(help_src_desc); - help_diff_src_desc = - compress_spatial_dimensions_to_channel(help_diff_src_desc); - help_diff_dst_desc = - compress_spatial_dimensions_to_channel(help_diff_dst_desc); - } else { - if ((help_src_desc != help_diff_dst_desc) || - (help_src_desc != help_diff_src_desc) || - (help_diff_dst_desc != help_diff_src_desc)) { - help_src_desc = bn_reorder_memory_to_channel_major_format( - true, help_src_desc, src, &reordered_src); - help_diff_dst_desc = bn_reorder_memory_to_channel_major_format( - true, help_diff_dst_desc, diff_dst, &reordered_diff_dst); - help_diff_src_desc = bn_reorder_memory_to_channel_major_format( - false, help_diff_src_desc, diff_src, &reordered_diff_src); - actual_diff_src_desc = help_diff_src_desc; - } - } - - help_diff_scale_bias_desc = - get_bn_scale_bias_mean_var_desc(help_diff_scale_bias_desc, mode); - help_mean_var_desc = - get_bn_scale_bias_mean_var_desc(help_mean_var_desc, mode); - - auto forward_primitive = - create_primitive_desc<::dnnl::batch_normalization_forward>( - ::dnnl::prop_kind::forward_training, help_src_desc, - help_diff_dst_desc, epsilon, - ::dnnl::normalization_flags::use_scale | - ::dnnl::normalization_flags::use_shift); - auto primitive_args = - create_primitive_args_or_get<::dnnl::batch_normalization_backward>( - ::dnnl::prop_kind::backward, help_diff_src_desc, help_diff_dst_desc, - help_src_desc, epsilon, - ::dnnl::normalization_flags::use_scale | - ::dnnl::normalization_flags::use_shift, forward_primitive); - - void *dst_cache = nullptr; - if (!saved_mean && !saved_var) { - dst_cache = allocate(diff_dst_desc); - if (!reordered_saved_mean) { - reordered_saved_mean = allocate(mean_var_desc); - } - if (!reordered_saved_var) { - reordered_saved_var = allocate(mean_var_desc); - } - if (!bias) { - _q->fill(reordered_bias, 0, diff_scale_bias_desc.get_size()); - } - - batch_normalization_forward_internal( - true, mode, epsilon, 0.f, 1.f, src_desc, src, 0.f, diff_dst_desc, - dst_cache, diff_scale_bias_desc, scale, bias ? bias : reordered_bias, - mean_var_desc, reordered_saved_mean, reordered_saved_var, nullptr, - nullptr); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, help_src_desc, - reordered_src ? reordered_src : src); - insert_arg(primitive_args.second.args, DNNL_ARG_SCALE, - help_diff_scale_bias_desc, - reordered_scale ? reordered_scale : scale); - insert_arg(primitive_args.second.args, DNNL_ARG_MEAN, help_mean_var_desc, - reordered_saved_mean ? reordered_saved_mean : saved_mean); - insert_arg(primitive_args.second.args, DNNL_ARG_VARIANCE, help_mean_var_desc, - reordered_saved_var ? reordered_saved_var : saved_var); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, help_diff_src_desc, - reordered_diff_dst ? reordered_diff_dst : diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, help_diff_src_desc, - reordered_diff_src ? reordered_diff_src : diff_src); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SCALE, - help_diff_scale_bias_desc, - reordered_diff_scale ? reordered_diff_scale : diff_scale); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SHIFT, - help_diff_scale_bias_desc, - reordered_diff_bias ? reordered_diff_bias : diff_bias); - - sycl::event e = execute_primitive<::dnnl::batch_normalization_backward>( - primitive_args, - {{alpha_data, beta_data, DNNL_ARG_DIFF_SRC, help_diff_src_desc, - reordered_diff_src ? reordered_diff_src : diff_src}, - {alpha_param, beta_param, DNNL_ARG_DIFF_SCALE, help_diff_scale_bias_desc, - reordered_diff_scale ? reordered_diff_scale : diff_scale}, - {alpha_param, beta_param, DNNL_ARG_DIFF_SHIFT, help_diff_scale_bias_desc, - reordered_diff_bias ? reordered_diff_bias : diff_bias}}); - if (actual_diff_src_desc != diff_src_desc.get_desc() && reordered_diff_src) { - e = async_reorder(1.f, actual_diff_src_desc, reordered_diff_src, 0.f, - diff_src_desc, diff_src); - } - if (actual_diff_scale_bias_desc != diff_scale_bias_desc.get_desc() && - reordered_diff_scale && reordered_diff_bias) { - async_reorder(1.f, actual_diff_scale_bias_desc, reordered_diff_scale, 0.f, - diff_scale_bias_desc, diff_scale); - e = async_reorder(1.f, actual_diff_scale_bias_desc, reordered_diff_bias, 0.f, - diff_scale_bias_desc, diff_bias); - } - return exit_primitive(e); -} - -inline -sycl::event engine_ext::batch_normalization_forward_internal( - bool is_infer, batch_normalization_mode mode, float epsilon, float factor, - float alpha, const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var, - void *running_mean, void *running_var) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - enter_primitive(src_desc.get_size() + 5 * dst_desc.get_size() + - scale_bias_desc.get_size() * 2 + - mean_var_desc.get_size() * 9); - void *reordered_src = nullptr, *reordered_dst = nullptr, - *reordered_scale = nullptr, *reordered_bias = nullptr, - *reordered_saved_mean = nullptr, *reordered_saved_var = nullptr; - ::dnnl::memory::desc help_src_desc = src_desc.get_desc(); - ::dnnl::memory::desc help_dst_desc = dst_desc.get_desc(); - ::dnnl::memory::desc help_scale_bias_desc = scale_bias_desc.get_desc(); - ::dnnl::memory::desc help_mean_var_desc = mean_var_desc.get_desc(); - ::dnnl::memory::desc actual_dst_desc = help_dst_desc; - ::dnnl::memory::desc actual_mean_var_desc = help_mean_var_desc; - - if (mode == batch_normalization_mode::per_activation) { - help_src_desc = bn_reorder_memory_to_channel_major_format(true, help_src_desc, src, - &reordered_src); - help_dst_desc = bn_reorder_memory_to_channel_major_format( - false, help_dst_desc, dst, &reordered_dst); - actual_dst_desc = help_dst_desc; - help_scale_bias_desc = bn_reorder_memory_to_channel_major_format( - true, help_scale_bias_desc, scale, &reordered_scale); - bn_reorder_memory_to_channel_major_format(true, help_scale_bias_desc, bias, - &reordered_bias); - help_mean_var_desc = bn_reorder_memory_to_channel_major_format( - is_infer, help_mean_var_desc, saved_mean, - &reordered_saved_mean); - actual_mean_var_desc = help_mean_var_desc; - bn_reorder_memory_to_channel_major_format(is_infer, - help_mean_var_desc, saved_var, - &reordered_saved_var); - help_src_desc = compress_spatial_dimensions_to_channel(help_src_desc); - help_dst_desc = compress_spatial_dimensions_to_channel(help_dst_desc); - } else { - if (help_src_desc != help_dst_desc) { - help_src_desc = bn_reorder_memory_to_channel_major_format( - true, help_src_desc, src, &reordered_src); - help_dst_desc = bn_reorder_memory_to_channel_major_format( - false, help_dst_desc, dst, &reordered_dst); - actual_dst_desc = help_dst_desc; - } - } - help_scale_bias_desc = - get_bn_scale_bias_mean_var_desc(help_scale_bias_desc, mode); - help_mean_var_desc = - get_bn_scale_bias_mean_var_desc(help_mean_var_desc, mode); - - ::dnnl::prop_kind kind; - ::dnnl::normalization_flags flag = ::dnnl::normalization_flags::use_scale | - ::dnnl::normalization_flags::use_shift; - if (is_infer) { - kind = ::dnnl::prop_kind::forward_inference; - flag = ::dnnl::normalization_flags::use_global_stats | flag; - } else { - kind = ::dnnl::prop_kind::forward_training; - } - auto primitive_args = - create_primitive_args_or_get<::dnnl::batch_normalization_forward>( - kind, help_src_desc, help_dst_desc, epsilon, flag); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, help_src_desc, - reordered_src ? reordered_src : src); - insert_arg(primitive_args.second.args, DNNL_ARG_SCALE, help_scale_bias_desc, - reordered_scale ? reordered_scale : scale); - insert_arg(primitive_args.second.args, DNNL_ARG_SHIFT, help_scale_bias_desc, - reordered_bias ? reordered_bias : bias); - insert_arg(primitive_args.second.args, DNNL_ARG_MEAN, help_mean_var_desc, - reordered_saved_mean ? reordered_saved_mean - : saved_mean); - insert_arg(primitive_args.second.args, DNNL_ARG_VARIANCE, help_mean_var_desc, - reordered_saved_var ? reordered_saved_var - : saved_var); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, help_dst_desc, - reordered_dst ? reordered_dst : dst); - sycl::event e = execute_primitive<::dnnl::batch_normalization_forward>(primitive_args, - {{alpha, beta, DNNL_ARG_DST, help_dst_desc, - reordered_dst ? reordered_dst : dst}}); - - if (!is_infer && running_var) { - auto src_ndim = src_desc.get_desc().get_ndims(); - auto src_dims = src_desc.get_dims(); - int element_num = src_dims[0]; - if (mode == batch_normalization_mode::spatial) { - for (int index = 2; index < src_ndim; index++) { - element_num *= src_dims[index]; - } - } - float unbias_factor = element_num / (element_num - 1.f); - async_scale(1.f - factor, mean_var_desc, running_var); - e = async_sum(factor * unbias_factor, mean_var_desc, - reordered_saved_var ? reordered_saved_var : saved_var, - 1.f, mean_var_desc, running_var); - } - if (!is_infer && running_mean) { - e = async_sum(factor, mean_var_desc, - reordered_saved_mean ? reordered_saved_mean : saved_mean, - (1.f - factor), mean_var_desc, running_mean); - } - if (reordered_dst && (actual_dst_desc != dst_desc.get_desc())) { - e = async_reorder(1.f, actual_dst_desc, reordered_dst, 0.f, dst_desc, dst); - } - if (!is_infer && reordered_saved_mean && reordered_saved_var && saved_mean && - saved_var && (actual_mean_var_desc != mean_var_desc.get_desc())) { - e = async_reorder(1.f, actual_mean_var_desc, reordered_saved_mean, 0.f, - mean_var_desc, saved_mean); - e = async_reorder(1.f, actual_mean_var_desc, reordered_saved_var, 0.f, - mean_var_desc, saved_var); - } - return exit_primitive(e); -} - -inline -sycl::event engine_ext::rnn_forward_internal( - const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, void *src, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &iter_desc, void *src_iter, void *dst_iter, - const memory_desc_ext &iter_c_desc, void *src_iter_c, void *dst_iter_c, - size_t weight_size, void *weight, size_t workspace_size, void *workspace, - size_t scratchpad_size, void *scratchpad, bool is_get_execution_args, - size_t *weight_size_query, size_t *workspace_size_query, - size_t *scratchpad_size_query) { - ::dnnl::memory::data_type src_dt; - ::dnnl::memory::format_tag src_format_tag; - rnn_mode mode; - rnn_bias_mode bias_mode; - rnn_direction direction; - dpct::library_data_t dt; - int direction_num = 1, input_size = 0, hidden_size = 0, projection_size = 0, - layer_size = 0, gate_num = 1, output_size = 0, data_type_size = 0, - seq_length = 1, batch_size = 1; - std::vector data = {src, dst, src_iter, dst_iter, - src_iter_c, dst_iter_c, weight, workspace, - scratchpad}; - std::vector offset(6, 0); - void *input_layer_cache = nullptr, *hidden_layer_cache = nullptr; - sycl::event e; - enter_primitive(src_desc.get_size() * 2); - desc.get(&mode, &bias_mode, &direction, &dt, &input_size, &hidden_size, - &projection_size, &layer_size); - - get_rnn_configuration(src_desc.get_desc(), direction, mode, dt, hidden_size, - &src_dt, &src_format_tag, &projection_size, - &output_size, &seq_length, &batch_size, &direction_num, - &gate_num); - - if (direction == rnn_direction::bidirectional) { - // Here to combine the oneDNN bidirectional_sum and - // bidirectional_concat config, so call execute_rnn_forward_primitive - // twice. - if (layer_size > 1) { - if (!is_get_execution_args) { - input_layer_cache = allocate(src_desc); - hidden_layer_cache = allocate(src_desc); - _q->memcpy(input_layer_cache, src, src_desc.get_size()); - } - data[0] = input_layer_cache; - data[1] = hidden_layer_cache; - e = execute_rnn_forward_primitive( - mode, kind, ::dnnl::rnn_direction::bidirectional_sum, bias_mode, - src_dt, src_format_tag, seq_length, batch_size, output_size, - output_size, 1, direction_num, hidden_size, gate_num, projection_size, - data, offset, layer_size - 1, weight_size_query, workspace_size_query, - scratchpad_size_query); - data[0] = - ((layer_size - 1) % 2 == 0) ? input_layer_cache : hidden_layer_cache; - data[1] = dst; - } - e = execute_rnn_forward_primitive( - mode, kind, ::dnnl::rnn_direction::bidirectional_concat, bias_mode, - src_dt, src_format_tag, seq_length, batch_size, output_size, - 2 * output_size, 1, direction_num, hidden_size, gate_num, - projection_size, data, offset, 1, weight_size_query, - workspace_size_query, scratchpad_size_query); - } else { - e = execute_rnn_forward_primitive( - mode, kind, ::dnnl::rnn_direction::unidirectional_left2right, bias_mode, - src_dt, src_format_tag, seq_length, batch_size, output_size, - output_size, layer_size, direction_num, hidden_size, gate_num, - projection_size, data, offset, 1, weight_size_query, - workspace_size_query, scratchpad_size_query); - } - - return exit_primitive(e); -} - -inline -sycl::event engine_ext::execute_rnn_forward_primitive( - rnn_mode mode, ::dnnl::prop_kind kind, ::dnnl::rnn_direction direction, - rnn_bias_mode bias_mode, ::dnnl::memory::data_type dt, - ::dnnl::memory::format_tag tag, int seq_length, int batch_size, int src_c, - int dst_c, int layer_size, int direction_num, int hidden_size, int gate_num, - int projection_size, std::vector &data, std::vector &offset, - int iter_num, size_t *weight_size, size_t *workspace_size, - size_t *scratchpad_size) { - - sycl::event e; - ::dnnl::primitive *p = nullptr; - std::unordered_map *args = nullptr; - detail::primitive_cache_key_type key; - std::unordered_map *execution_args; - ::dnnl::memory::desc bias_desc( - {layer_size, direction_num, gate_num, hidden_size}, dt, - ::dnnl::memory::format_tag::ldgo); - ::dnnl::memory::desc weight_layer_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldigo); - ::dnnl::memory::desc weight_iter_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldigo); - ::dnnl::memory::desc projection_desc; - if (projection_size) { - projection_desc = ::dnnl::memory::desc( - {layer_size, direction_num, hidden_size, projection_size}, dt, - ::dnnl::memory::format_tag::ldio); - } - - if (weight_size) { - *weight_size += - (weight_layer_desc.get_size() + weight_iter_desc.get_size() + - projection_desc.get_size() + bias_desc.get_size()) * - iter_num; - return e; - } - - ::dnnl::memory::desc src_desc({seq_length, batch_size, src_c}, dt, tag); - ::dnnl::memory::desc dst_desc({seq_length, batch_size, dst_c}, dt, tag); - ::dnnl::memory::desc iter_desc( - {layer_size, direction_num, batch_size, - projection_size ? projection_size : hidden_size}, - dt, ::dnnl::memory::format_tag::ldnc); - ::dnnl::memory::desc iter_c_desc( - {layer_size, direction_num, batch_size, hidden_size}, dt, - ::dnnl::memory::format_tag::ldnc); - - ::dnnl::memory::desc workspace_desc; - ::dnnl::memory::desc scratchpad_desc; - ::dnnl::primitive_attr attr; - attr.set_scratchpad_mode(::dnnl::scratchpad_mode::user); - - if (mode == rnn_mode::vanilla_relu || mode == rnn_mode::vanilla_tanh) { - auto primitive = create_primitive_args_or_get<::dnnl::vanilla_rnn_forward>( - kind, - mode == rnn_mode::vanilla_relu ? ::dnnl::algorithm::eltwise_relu - : ::dnnl::algorithm::eltwise_tanh, - direction, src_desc, iter_desc, weight_layer_desc, weight_iter_desc, - bias_desc, dst_desc, iter_desc, attr); - - auto pd = get_primitive_desc<::dnnl::vanilla_rnn_forward>( - primitive.second.primitive); - - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - if (workspace_size && scratchpad_size) { - *workspace_size += workspace_desc.get_size() * iter_num; - *scratchpad_size = scratchpad_desc.get_size() > *scratchpad_size - ? scratchpad_desc.get_size() - : *scratchpad_size; - } else { - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } - } else if (mode == rnn_mode::gru) { - auto primitive = create_primitive_args_or_get<::dnnl::gru_forward>( - kind, direction, src_desc, iter_desc, weight_layer_desc, - weight_iter_desc, bias_desc, dst_desc, iter_desc, attr); - - auto pd = - get_primitive_desc<::dnnl::gru_forward>(primitive.second.primitive); - - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - if (workspace_size && scratchpad_size) { - *workspace_size += workspace_desc.get_size() * iter_num; - *scratchpad_size = scratchpad_desc.get_size() > *scratchpad_size - ? scratchpad_desc.get_size() - : *scratchpad_size; - } else { - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } - } else if (mode == rnn_mode::lstm) { - auto primitive = create_primitive_args_or_get<::dnnl::lstm_forward>( - kind, direction, src_desc, iter_desc, iter_c_desc, weight_layer_desc, - weight_iter_desc, ::dnnl::memory::desc(), projection_desc, bias_desc, - dst_desc, iter_desc, iter_c_desc, attr); - - auto pd = - get_primitive_desc<::dnnl::lstm_forward>(primitive.second.primitive); - - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - if (workspace_size && scratchpad_size) { - *workspace_size += workspace_desc.get_size() * iter_num; - *scratchpad_size = scratchpad_desc.get_size() > *scratchpad_size - ? scratchpad_desc.get_size() - : *scratchpad_size; - } else { - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } - } - - for (int i = 0; i < iter_num; i++) { - void *in_cache = data[0], *out_cache = data[1], *dst_iter_c_cache = nullptr, - *dst_iter_cache = ((uint8_t *)(data[3]) + offset[1]); - if (mode == rnn_mode::lstm) { - dst_iter_c_cache = (uint8_t *)(data[4]) + offset[2]; - } - if (!workspace_size) { - insert_arg(args, DNNL_ARG_SRC_LAYER, src_desc, data[0]); - insert_arg(args, DNNL_ARG_DST_LAYER, dst_desc, data[1]); - insert_arg(args, DNNL_ARG_SCRATCHPAD, scratchpad_desc, data[8]); - auto insert_rnn_arg = [&](int arg_name, ::dnnl::memory::desc &d, void *data, - int &offset) { - insert_arg(args, arg_name, d, (uint8_t *)data + offset); - offset += d.get_size(); - }; - insert_rnn_arg(DNNL_ARG_SRC_ITER, iter_desc, data[2], offset[0]); - insert_rnn_arg(DNNL_ARG_DST_ITER, iter_desc, data[3], offset[1]); - - if (mode == rnn_mode::lstm) { - insert_rnn_arg(DNNL_ARG_SRC_ITER_C, iter_c_desc, data[4], offset[2]); - insert_rnn_arg(DNNL_ARG_DST_ITER_C, iter_c_desc, data[5], offset[3]); - } - insert_rnn_arg(DNNL_ARG_WEIGHTS_LAYER, weight_layer_desc, data[6], - offset[4]); - insert_rnn_arg(DNNL_ARG_WEIGHTS_ITER, weight_iter_desc, data[6], offset[4]); - if (projection_size) { - insert_rnn_arg(DNNL_ARG_WEIGHTS_PROJECTION, projection_desc, data[6], - offset[4]); - } - if (bias_mode == rnn_bias_mode::none) { - _q->memset((uint8_t *)(data[6]) + offset[4], 0, bias_desc.get_size()); - } - insert_rnn_arg(DNNL_ARG_BIAS, bias_desc, data[6], offset[4]); - if (kind == ::dnnl::prop_kind::forward_training) { - insert_rnn_arg(DNNL_ARG_WORKSPACE, workspace_desc, data[7], offset[5]); - } - if (mode == rnn_mode::vanilla_relu || mode == rnn_mode::vanilla_tanh) { - execute_primitive<::dnnl::vanilla_rnn_forward>( - {key, {static_cast<::dnnl::vanilla_rnn_forward *>(p), args}}); - } else if (mode == rnn_mode::gru) { - execute_primitive<::dnnl::gru_forward>( - {key, {static_cast<::dnnl::gru_forward *>(p), args}}); - } else if (mode == rnn_mode::lstm) { - execute_primitive<::dnnl::lstm_forward>( - {key, {static_cast<::dnnl::lstm_forward *>(p), args}}); - } - if (i != iter_num - 1) { - std::swap(data[0], data[1]); - } - } - if (kind == ::dnnl::prop_kind::forward_training) { - if (workspace_size) { - *workspace_size += - (src_desc.get_size() + dst_desc.get_size() + iter_desc.get_size()); - if (mode == rnn_mode::lstm) { - *workspace_size += iter_c_desc.get_size(); - } - } else { - _q->memcpy((uint8_t *)(data[7]) + offset[5], in_cache, - src_desc.get_size()); - offset[5] += src_desc.get_size(); - _q->memcpy((uint8_t *)(data[7]) + offset[5], out_cache, - dst_desc.get_size()); - offset[5] += dst_desc.get_size(); - _q->memcpy((uint8_t *)(data[7]) + offset[5], dst_iter_cache, - iter_desc.get_size()); - offset[5] += iter_desc.get_size(); - if (mode == rnn_mode::lstm) { - _q->memcpy((uint8_t *)(data[7]) + offset[5], dst_iter_c_cache, - iter_c_desc.get_size()); - offset[5] += iter_c_desc.get_size(); - } - } - } - } - return e; -} - -inline -sycl::event engine_ext::execute_rnn_backward_primitive( - rnn_mode mode, ::dnnl::rnn_direction direction, rnn_bias_mode bias_mode, - ::dnnl::memory::data_type dt, ::dnnl::memory::format_tag tag, - int seq_length, int batch_size, int src_c, int dst_c, int layer_size, - int direction_num, int hidden_size, int gate_num, int projection_size, - std::vector &data, std::vector &offset, int iter_num) { - - sycl::event e; - ::dnnl::primitive *p = nullptr; - std::unordered_map *args = nullptr; - detail::primitive_cache_key_type key; - ::dnnl::prop_kind fkind = ::dnnl::prop_kind::forward_training; - ::dnnl::prop_kind bkind = ::dnnl::prop_kind::backward; - ::dnnl::memory::desc bias_desc( - {layer_size, direction_num, gate_num, hidden_size}, dt, - ::dnnl::memory::format_tag::ldgo); - ::dnnl::memory::desc weight_layer_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldigo); - ::dnnl::memory::desc weight_iter_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldigo); - ::dnnl::memory::desc diff_weight_layer_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldgoi); - ::dnnl::memory::desc diff_weight_iter_desc( - {layer_size, direction_num, - projection_size ? projection_size : hidden_size, gate_num, hidden_size}, - dt, ::dnnl::memory::format_tag::ldgoi); - ::dnnl::memory::desc projection_desc, diff_projection_desc; - if (projection_size) { - projection_desc = ::dnnl::memory::desc( - {layer_size, direction_num, hidden_size, projection_size}, dt, - ::dnnl::memory::format_tag::ldio); - diff_projection_desc = ::dnnl::memory::desc( - {layer_size, direction_num, hidden_size, projection_size}, dt, - ::dnnl::memory::format_tag::ldoi); - } - - ::dnnl::memory::desc src_desc({seq_length, batch_size, src_c}, dt, tag); - ::dnnl::memory::desc dst_desc({seq_length, batch_size, dst_c}, dt, tag); - ::dnnl::memory::desc iter_desc( - {layer_size, direction_num, batch_size, - projection_size ? projection_size : hidden_size}, - dt, ::dnnl::memory::format_tag::ldnc); - ::dnnl::memory::desc iter_c_desc( - {layer_size, direction_num, batch_size, hidden_size}, dt, - ::dnnl::memory::format_tag::ldnc); - - ::dnnl::memory::desc workspace_desc; - ::dnnl::memory::desc scratchpad_desc; - ::dnnl::primitive_attr attr; - attr.set_scratchpad_mode(::dnnl::scratchpad_mode::user); - - if (mode == rnn_mode::vanilla_relu || mode == rnn_mode::vanilla_tanh) { - auto fpd = create_primitive_desc<::dnnl::vanilla_rnn_forward>( - fkind, - mode == rnn_mode::vanilla_relu ? ::dnnl::algorithm::eltwise_relu - : ::dnnl::algorithm::eltwise_tanh, - direction, src_desc, iter_desc, weight_layer_desc, weight_iter_desc, - bias_desc, dst_desc, iter_desc, attr); - auto primitive = create_primitive_args_or_get<::dnnl::vanilla_rnn_backward>( - bkind, - mode == rnn_mode::vanilla_relu ? ::dnnl::algorithm::eltwise_relu - : ::dnnl::algorithm::eltwise_tanh, - direction, src_desc, iter_desc, diff_weight_layer_desc, - diff_weight_iter_desc, bias_desc, dst_desc, iter_desc, src_desc, - iter_desc, weight_layer_desc, weight_iter_desc, bias_desc, dst_desc, - iter_desc, fpd, attr); - auto pd = get_primitive_desc<::dnnl::vanilla_rnn_backward>( - primitive.second.primitive); - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } else if (mode == rnn_mode::gru) { - auto fpd = create_primitive_desc<::dnnl::gru_forward>( - fkind, direction, src_desc, iter_desc, weight_layer_desc, - weight_iter_desc, bias_desc, dst_desc, iter_desc, attr); - auto primitive = create_primitive_args_or_get<::dnnl::gru_backward>( - bkind, direction, src_desc, iter_desc, diff_weight_layer_desc, - diff_weight_iter_desc, bias_desc, dst_desc, iter_desc, src_desc, - iter_desc, weight_layer_desc, weight_iter_desc, bias_desc, dst_desc, - iter_desc, fpd, attr); - auto pd = - get_primitive_desc<::dnnl::gru_backward>(primitive.second.primitive); - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } else if (mode == rnn_mode::lstm) { - auto fpd = create_primitive_desc<::dnnl::lstm_forward>( - fkind, direction, src_desc, iter_desc, iter_c_desc, weight_layer_desc, - weight_iter_desc, ::dnnl::memory::desc(), projection_desc, bias_desc, - dst_desc, iter_desc, iter_c_desc, attr); - auto primitive = create_primitive_args_or_get<::dnnl::lstm_backward>( - bkind, direction, src_desc, iter_desc, iter_c_desc, - diff_weight_layer_desc, diff_weight_iter_desc, ::dnnl::memory::desc(), - diff_projection_desc, bias_desc, dst_desc, iter_desc, iter_c_desc, - src_desc, iter_desc, iter_c_desc, weight_layer_desc, weight_iter_desc, - ::dnnl::memory::desc(), projection_desc, bias_desc, dst_desc, iter_desc, - iter_c_desc, fpd, attr); - auto pd = - get_primitive_desc<::dnnl::lstm_backward>(primitive.second.primitive); - workspace_desc = pd.workspace_desc(); - scratchpad_desc = pd.scratchpad_desc(); - key = primitive.first; - p = primitive.second.primitive; - args = primitive.second.args; - } - - for (int i = 0; i < iter_num; i++) { - insert_arg(args, DNNL_ARG_DIFF_SRC_LAYER, src_desc, data[8]); - insert_arg(args, DNNL_ARG_DIFF_DST_LAYER, dst_desc, data[9]); - insert_arg(args, DNNL_ARG_SCRATCHPAD, scratchpad_desc, data[15]); - auto insert_rnn_arg = [&](int arg_name, ::dnnl::memory::desc &d, void *data, - int &offset) { - offset += d.get_size(); - insert_arg(args, arg_name, d, (uint8_t *)data - offset); - }; - if (mode == rnn_mode::lstm) { - insert_rnn_arg(DNNL_ARG_DST_ITER_C, iter_c_desc, data[7], offset[0]); - insert_rnn_arg(DNNL_ARG_SRC_ITER_C, iter_c_desc, data[4], offset[2]); - } - insert_rnn_arg(DNNL_ARG_DST_ITER, iter_desc, data[7], offset[0]); - insert_rnn_arg(DNNL_ARG_DST_LAYER, dst_desc, data[7], offset[0]); - insert_rnn_arg(DNNL_ARG_SRC_LAYER, src_desc, data[7], offset[0]); - insert_rnn_arg(DNNL_ARG_WORKSPACE, workspace_desc, data[7], offset[0]); - insert_rnn_arg(DNNL_ARG_SRC_ITER, iter_desc, data[2], offset[1]); - insert_rnn_arg(DNNL_ARG_BIAS, bias_desc, data[6], offset[3]); - if (projection_size) { - insert_rnn_arg(DNNL_ARG_WEIGHTS_PROJECTION, diff_projection_desc, data[6], - offset[3]); - } - insert_rnn_arg(DNNL_ARG_WEIGHTS_ITER, diff_weight_iter_desc, data[6], - offset[3]); - insert_rnn_arg(DNNL_ARG_WEIGHTS_LAYER, diff_weight_layer_desc, data[6], - offset[3]); - insert_rnn_arg(DNNL_ARG_DIFF_SRC_ITER, iter_desc, data[10], offset[4]); - insert_rnn_arg(DNNL_ARG_DIFF_DST_ITER, iter_desc, data[11], offset[5]); - if (mode == rnn_mode::lstm) { - insert_rnn_arg(DNNL_ARG_DIFF_SRC_ITER_C, iter_c_desc, data[12], offset[6]); - insert_rnn_arg(DNNL_ARG_DIFF_DST_ITER_C, iter_c_desc, data[13], offset[7]); - } - insert_rnn_arg(DNNL_ARG_DIFF_BIAS, bias_desc, data[14], offset[8]); - if (bias_mode == rnn_bias_mode::none) { - _q->memset((uint8_t *)(data[14]) - offset[8], 0, bias_desc.get_size()); - } - if (projection_size) { - insert_rnn_arg(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, projection_desc, data[14], - offset[8]); - } - insert_rnn_arg(DNNL_ARG_DIFF_WEIGHTS_ITER, weight_iter_desc, data[14], - offset[8]); - insert_rnn_arg(DNNL_ARG_DIFF_WEIGHTS_LAYER, weight_layer_desc, data[14], - offset[8]); - if (mode == rnn_mode::vanilla_relu || mode == rnn_mode::vanilla_tanh) { - e = execute_primitive<::dnnl::vanilla_rnn_backward>( - {key, {static_cast<::dnnl::vanilla_rnn_backward *>(p), args}}); - } else if (mode == rnn_mode::gru) { - e = execute_primitive<::dnnl::gru_backward>( - {key, {static_cast<::dnnl::gru_backward *>(p), args}}); - } else if (mode == rnn_mode::lstm) { - e = execute_primitive<::dnnl::lstm_backward>( - {key, {static_cast<::dnnl::lstm_backward *>(p), args}}); - } - if (i != iter_num - 1) { - std::swap(data[8], data[9]); - } - } - return e; -} - -#define EMPTY_CACHE_KEY(type) \ - template <> \ - inline void engine_ext::generate_cache_key(std::string & key_buffer, \ - const type &arg) {} - -EMPTY_CACHE_KEY(::dnnl::engine) -EMPTY_CACHE_KEY(::dnnl::convolution_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::eltwise_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::softmax_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::pooling_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::lrn_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::batch_normalization_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::vanilla_rnn_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::lstm_forward::primitive_desc) -EMPTY_CACHE_KEY(::dnnl::gru_forward::primitive_desc) -#undef EMPTY_CACHE_KEY - -template <> -inline void engine_ext::generate_cache_key>( - std::string &key_buffer, const std::vector &vec) { - key_buffer.append((char *)vec.data(), vec.size() * sizeof(float)); -} - -template <> -inline void engine_ext::generate_cache_key<::dnnl::primitive_attr>( - std::string &key_buffer, const ::dnnl::primitive_attr &attr) { - if (!attr) { - return; - } - auto math_mode = (uint8_t)attr.get_fpmath_mode(); - key_buffer.append((char *)&math_mode, sizeof(uint8_t)); -} - -template <> -inline void engine_ext::generate_cache_key<::dnnl::memory::dims>( - std::string &key_buffer, const ::dnnl::memory::dims &dims) { - key_buffer.append((char *)dims.data(), dims.size() * sizeof(int64_t)); -} - -template <> -inline void engine_ext::generate_cache_key<::dnnl::memory::desc>( - std::string &key_buffer, const ::dnnl::memory::desc &desc) { - uint8_t params[3] = {(uint8_t)desc.get_format_kind(), - (uint8_t)desc.get_ndims(), - (uint8_t)desc.get_data_type()}; - generate_cache_key(key_buffer, desc.get_inner_blks()); - generate_cache_key(key_buffer, desc.get_dims()); - generate_cache_key(key_buffer, desc.get_strides()); -} - -template -void engine_ext::generate_cache_key(std::string &key_buffer, const T &arg) { - key_buffer.append((char *)&arg, sizeof(T)); -} - -template -void engine_ext::generate_cache_key(std::string &key_buffer, const T &first_arg, - const args_type &...args) { - generate_cache_key(key_buffer, first_arg); - generate_cache_key(key_buffer, args...); -} - -template -std::pair -engine_ext::create_primitive_args_or_get(args_type &&...args) { - std::string buffer; - buffer.reserve(512); - generate_cache_key(buffer, std::forward(args)...); - buffer.append(std::to_string(_engine_id)); - auto value = _primitive_cache.get(buffer); - primitive_type *p = nullptr; - std::unordered_map *a = nullptr; - if (value) { - p = (primitive_type *)value->_primitive; - a = value->_args; - } else { - p = new primitive_type(create_primitive_desc( - std::forward(args)...)); - a = new std::unordered_map(); - } - return {buffer, {p, a}}; -} - -template -typename primitive_type::primitive_desc -engine_ext::get_primitive_desc(::dnnl::primitive *p) { - return typename primitive_type::primitive_desc( - const_cast(p->get_primitive_desc())); -} - -template -typename primitive_type::primitive_desc -engine_ext::create_primitive_desc(args_type &&...args) { - return typename primitive_type::primitive_desc( - *_eng, std::forward(args)...); -} - -inline -void engine_ext::fill(const memory_desc_ext &src_desc, void *src, - const void *valuePtr) { - async_fill(src_desc, src, valuePtr).wait(); -} - -inline -void engine_ext::reorder(float alpha, const memory_desc_ext &src_desc, - void *src, float beta, const memory_desc_ext &dst_desc, - void *dst) { - async_reorder(alpha, src_desc, src, beta, dst_desc, dst).wait(); -} - -inline -void engine_ext::scale(float alpha, const memory_desc_ext &src_desc, - void *src) { - async_scale(alpha, src_desc, src).wait(); -} -inline -void engine_ext::sum(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst) { - async_sum(alpha, src_desc, src, beta, dst_desc, dst).wait(); -} -inline -void engine_ext::activation_forward(activation_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst) { - async_activation_forward(desc, alpha, src_desc, src, beta, dst_desc, dst) - .wait(); -} -inline -void engine_ext::activation_backward( - activation_desc &desc, float alpha, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src) { - async_activation_backward(desc, alpha, dst_desc, dst, diff_dst_desc, diff_dst, - src_desc, src, beta, diff_src_desc, diff_src) - .wait(); -} -inline -void engine_ext::pooling_forward(pooling_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, - ::dnnl::memory *workspace) { - async_pooling_forward(desc, alpha, src_desc, src, beta, dst_desc, dst, - workspace).wait(); -} - -inline -void engine_ext::pooling_backward( - pooling_desc &desc, float alpha, const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src, - ::dnnl::memory *workspace) { - async_pooling_backward(desc, alpha, dst_desc, dst, diff_dst_desc, diff_dst, - src_desc, src, beta, diff_src_desc, diff_src, - workspace) - .wait(); -} - -inline -void engine_ext::softmax_forward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, void *dst) { - async_softmax_forward(alg, mode, alpha, src_desc, src, beta, dst_desc, dst) - .wait(); -} - -inline -void engine_ext::softmax_backward(softmax_algorithm alg, softmax_mode mode, - float alpha, const memory_desc_ext &dst_desc, - void *dst, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src) { - async_softmax_backward(alg, mode, alpha, dst_desc, dst, diff_dst_desc, - diff_dst, beta, diff_src_desc, diff_src) - .wait(); -} - -inline -void engine_ext::lrn_forward(lrn_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace) { - async_lrn_forward(desc, alpha, src_desc, src, beta, dst_desc, dst, workspace) - .wait(); -} - -inline -void engine_ext::lrn_backward(lrn_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, - void *diff_dst, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &diff_src_desc, - void *diff_src, - ::dnnl::memory *workspace) { - async_lrn_backward(desc, alpha, dst_desc, dst, diff_dst_desc, diff_dst, - src_desc, src, beta, diff_src_desc, diff_src, workspace) - .wait(); -} - -inline -sycl::event engine_ext::async_fill(const memory_desc_ext &src_desc, void *src, - const void *valuePtr) { - ::dnnl::memory::data_type dt = src_desc.get_desc().get_data_type(); - unsigned mem_size = src_desc.get_size(); - switch (dt) { - case ::dnnl::memory::data_type::f32: - return fill_with_type(_q, src, valuePtr, mem_size); - case ::dnnl::memory::data_type::f16: - return fill_with_type(_q, src, valuePtr, mem_size); - case ::dnnl::memory::data_type::s32: - return fill_with_type(_q, src, valuePtr, mem_size); - case ::dnnl::memory::data_type::s8: - return fill_with_type(_q, src, valuePtr, mem_size); - case ::dnnl::memory::data_type::u8: - return fill_with_type(_q, src, valuePtr, mem_size); - default: - throw std::runtime_error("async_fill: unsupported data type."); - } -} - -inline -sycl::event engine_ext::async_reorder(float alpha, const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, void *dst) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - enter_primitive(2 * dst_desc.get_size()); - - auto primitive_args = create_primitive_args_or_get<::dnnl::reorder>( - src_desc.get_desc(), *_eng, dst_desc.get_desc()); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - return exit_primitive(execute_primitive<::dnnl::reorder>( - primitive_args, {{alpha, beta, DNNL_ARG_DST, dst_desc, dst}})); -} - -inline -sycl::event engine_ext::async_scale(float alpha, const memory_desc_ext &src_desc, - void *src) { - if (alpha == 1.f) { - return sycl::event(); - } - size_t cache_size = src_desc.get_size(); - enter_primitive(cache_size); - void *src_cache = allocate(cache_size); - _q->memcpy(src_cache, src, cache_size); - auto primitive_args = create_primitive_args_or_get<::dnnl::eltwise_forward>( - ::dnnl::prop_kind::forward_inference, ::dnnl::algorithm::eltwise_linear, - src_desc.get_desc(), src_desc.get_desc(), alpha, 0.f); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src_cache); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, src_desc.get_desc(), - src); - - return exit_primitive( - execute_primitive<::dnnl::eltwise_forward>(primitive_args)); -} - -inline sycl::event -engine_ext::async_sum(float alpha, const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, void *dst) { - if (alpha == 0.f && beta == 1.f) { - return sycl::event(); - } - size_t cache_size = dst_desc.get_size(); - enter_primitive(cache_size); - void *dst_cache = allocate(dst_desc); - _q->memcpy(dst_cache, dst, cache_size); - - auto primitive_args = create_primitive_args_or_get<::dnnl::sum>( - std::vector{alpha, beta}, - std::vector<::dnnl::memory::desc>{src_desc.get_desc(), - dst_desc.get_desc()}); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - insert_arg(primitive_args.second.args, DNNL_ARG_MULTIPLE_SRC, - src_desc.get_desc(), src); - insert_arg(primitive_args.second.args, DNNL_ARG_MULTIPLE_SRC + 1, - dst_desc.get_desc(), dst_cache); - - return exit_primitive(execute_primitive<::dnnl::sum>(primitive_args)); -} - -inline -sycl::event engine_ext::async_binary(binary_op op, float alpha_0, - const memory_desc_ext &src_desc_0, void *src_0, - float alpha_1, const memory_desc_ext &src_desc_1, - void *src_1, float beta, - const memory_desc_ext &dst_desc, void *dst) { - ::dnnl::algorithm onednn_algorithm; - switch (op) { - case binary_op::max: - onednn_algorithm = ::dnnl::algorithm::binary_max; - break; - case binary_op::min: - onednn_algorithm = ::dnnl::algorithm::binary_min; - break; - case binary_op::add: - onednn_algorithm = ::dnnl::algorithm::binary_add; - break; - case binary_op::sub: - onednn_algorithm = ::dnnl::algorithm::binary_sub; - break; - case binary_op::mul: - onednn_algorithm = ::dnnl::algorithm::binary_mul; - break; - case binary_op::div: - onednn_algorithm = ::dnnl::algorithm::binary_div; - break; - case binary_op::sqrt: - onednn_algorithm = ::dnnl::algorithm::eltwise_sqrt; - break; - case binary_op::neg: - onednn_algorithm = ::dnnl::algorithm::eltwise_linear; - break; - } - size_t src0_cache_size = src_desc_0.get_size(); - size_t src1_cache_size = src_desc_1.get_size(); - size_t dst_cache_size = dst_desc.get_size(); - enter_primitive(2 * src0_cache_size + 2 * src1_cache_size + - 5 * dst_cache_size); - if (onednn_algorithm == ::dnnl::algorithm::eltwise_sqrt || - onednn_algorithm == ::dnnl::algorithm::eltwise_linear) { - void *src_cache = nullptr, *dst_cache = nullptr; - src_cache = allocate(src0_cache_size); - dst_cache = allocate(dst_cache_size); - _q->memcpy(src_cache, src_0, src0_cache_size); - _q->memcpy(dst_cache, dst, dst_cache_size); - async_scale(alpha_0, src_desc_0, src_cache); - async_scale(beta, dst_desc, dst_cache); - - // Let the output = 1 - input to simulate the behavior of neg. - auto primitive_args = create_primitive_args_or_get<::dnnl::eltwise_forward>( - ::dnnl::prop_kind::forward_inference, onednn_algorithm, - src_desc_0.get_desc(), dst_desc.get_desc(), -1.f, 1.f); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc_0.get_desc(), - src_cache); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - execute_primitive<::dnnl::eltwise_forward>( - primitive_args, {{1.f, 0.f, DNNL_ARG_DST, dst_desc, dst}}); - return exit_primitive( - async_sum(1.f, dst_desc, dst_cache, 1.f, dst_desc, dst)); - } - - void *src_0_cache = nullptr, *src_1_cache = nullptr, *dst_cache = nullptr; - - src_0_cache = allocate(src0_cache_size); - src_1_cache = allocate(src1_cache_size); - dst_cache = allocate(dst_cache_size); - - _q->memcpy(src_0_cache, src_0, src0_cache_size); - _q->memcpy(src_1_cache, src_1, src1_cache_size); - _q->memcpy(dst_cache, dst, dst_cache_size); - - async_scale(alpha_0, src_desc_0, src_0_cache); - async_scale(alpha_1, src_desc_1, src_1_cache); - async_scale(beta, dst_desc, dst_cache); - - auto primitive_args = create_primitive_args_or_get<::dnnl::binary>( - onednn_algorithm, src_desc_0.get_desc(), src_desc_1.get_desc(), - dst_desc.get_desc()); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_0, src_desc_0.get_desc(), - src_0_cache); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_1, src_desc_1.get_desc(), - src_1_cache); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - execute_primitive<::dnnl::binary>(primitive_args, - {{1.f, 0.f, DNNL_ARG_DST, dst_desc, dst}}); - return exit_primitive( - async_sum(1.f, dst_desc, dst_cache, 1.f, dst_desc, dst)); -} - -inline -sycl::event engine_ext::async_reduction(reduction_op op, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst) { - if (alpha == 0.f && beta == 1.f) { - return sycl::event(); - } - size_t src_cache_size = src_desc.get_size(); - size_t dst_cache_size = dst_desc.get_size(); - enter_primitive(3 * src_cache_size + 2 * dst_cache_size); - float p = 2.f; - ::dnnl::algorithm onednn_algorithm; - void *cache = nullptr; - switch (op) { - case reduction_op::amax: - cache = allocate(src_cache_size); - activation_desc adesc; - adesc.set_algorithm(::dnnl::algorithm::eltwise_abs); - async_activation_forward(adesc, 1.f, src_desc, src, 0.f, src_desc, cache); - onednn_algorithm = ::dnnl::algorithm::reduction_max; - src = cache; - break; - case reduction_op::max: - onednn_algorithm = ::dnnl::algorithm::reduction_max; - break; - case reduction_op::min: - onednn_algorithm = ::dnnl::algorithm::reduction_min; - break; - case reduction_op::sum: - onednn_algorithm = ::dnnl::algorithm::reduction_sum; - break; - case reduction_op::mean: - onednn_algorithm = ::dnnl::algorithm::reduction_mean; - break; - case reduction_op::mul: - onednn_algorithm = ::dnnl::algorithm::reduction_mul; - break; - case reduction_op::mul_no_zeros: - cache = allocate(src_cache_size); - transform_no_zero(src_desc, src, cache); - onednn_algorithm = ::dnnl::algorithm::reduction_mul; - src = cache; - break; - case reduction_op::norm1: - p = 1.f; - onednn_algorithm = ::dnnl::algorithm::reduction_norm_lp_power_p_sum; - break; - case reduction_op::norm2: - onednn_algorithm = ::dnnl::algorithm::reduction_norm_lp_sum; - break; - } - auto primitive_args = create_primitive_args_or_get<::dnnl::reduction>( - onednn_algorithm, src_desc.get_desc(), dst_desc.get_desc(), p, 0.f); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - return exit_primitive(execute_primitive<::dnnl::reduction>( - primitive_args, {{alpha, beta, DNNL_ARG_DST, dst_desc, dst}})); -} - -inline -sycl::event engine_ext::async_activation_forward(activation_desc &desc, float alpha, - const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, - void *dst) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - enter_primitive(2 * dst_desc.get_size()); - auto primitive_args = create_primitive_args_or_get<::dnnl::eltwise_forward>( - ::dnnl::prop_kind::forward, desc.get_algorithm(), src_desc.get_desc(), - dst_desc.get_desc(), desc.get_alpha(), desc.get_beta()); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - return exit_primitive(execute_primitive<::dnnl::eltwise_forward>( - primitive_args, {{alpha, beta, DNNL_ARG_DST, dst_desc, dst}})); -} - -inline -sycl::event engine_ext::async_activation_backward( - activation_desc &desc, float alpha, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src) { - - if (scale_parameter_preprocess({{alpha, beta, diff_src_desc, diff_src}})) { - return sycl::event(); - } - enter_primitive(2 * diff_src_desc.get_size()); - ::dnnl::memory::desc data_desc = dst_desc.get_desc(); - auto alg = desc.get_algorithm(); - if ((alg == ::dnnl::algorithm::eltwise_clip) || - (alg == ::dnnl::algorithm::eltwise_linear) || - (alg == ::dnnl::algorithm::eltwise_swish)) { - data_desc = src_desc.get_desc(); - } - auto primitive_args = create_primitive_args_or_get<::dnnl::eltwise_backward>( - alg, diff_src_desc.get_desc(), diff_dst_desc.get_desc(), data_desc, - desc.get_alpha(), desc.get_beta(), - create_primitive_desc<::dnnl::eltwise_forward>( - ::dnnl::prop_kind::forward, alg, src_desc.get_desc(), - dst_desc.get_desc(), desc.get_alpha(), desc.get_beta())); - - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, - diff_dst_desc.get_desc(), diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, - diff_src_desc.get_desc(), diff_src); - - return exit_primitive(execute_primitive<::dnnl::eltwise_backward>( - primitive_args, - {{alpha, beta, DNNL_ARG_DIFF_SRC, diff_src_desc, diff_src}})); -} - -inline -sycl::event engine_ext::async_pooling_forward(pooling_desc &desc, float alpha, - const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - enter_primitive(2 * dst_desc.get_size()); - int pooling_dim = desc.get_stride().size(); - std::vector dilation(pooling_dim, 0); - auto primitive_args = - create_primitive_args_or_get<::dnnl::pooling_forward>( - ::dnnl::prop_kind::forward_training, desc.get_algorithm(), - src_desc.get_desc(), dst_desc.get_desc(), desc.get_stride(), - desc.get_kernel(), dilation, desc.get_padding(), desc.get_padding()); - auto pd = get_primitive_desc<::dnnl::pooling_forward>( - primitive_args.second.primitive); - ::dnnl::memory ws_mem(pd.workspace_desc(), *_eng); - if (workspace) { - *workspace = ws_mem; - } else { - insert_workspace(src, ws_mem); - } - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, ws_mem); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - return exit_primitive(execute_primitive<::dnnl::pooling_forward>( - primitive_args, {{alpha, beta, DNNL_ARG_DST, dst_desc, dst}})); -} - -inline -sycl::event engine_ext::async_pooling_backward( - pooling_desc &desc, float alpha, const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src, - ::dnnl::memory *workspace) { - if (scale_parameter_preprocess({{alpha, beta, diff_src_desc, diff_src}})) { - return sycl::event(); - } - enter_primitive(2 * diff_src_desc.get_size()); - int pooling_dim = desc.get_stride().size(); - std::vector dilation(pooling_dim, 0); - auto primitive_args = create_primitive_args_or_get<::dnnl::pooling_backward>( - desc.get_algorithm(), diff_src_desc.get_desc(), diff_dst_desc.get_desc(), - desc.get_stride(), desc.get_kernel(), dilation, desc.get_padding(), - desc.get_padding(), - create_primitive_desc<::dnnl::pooling_forward>( - ::dnnl::prop_kind::forward_training, desc.get_algorithm(), - src_desc.get_desc(), dst_desc.get_desc(), desc.get_stride(), - desc.get_kernel(), dilation, desc.get_padding(), desc.get_padding())); - - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, - diff_dst_desc.get_desc(), diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, - diff_src_desc.get_desc(), diff_src); - - if (workspace) { - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, *workspace); - } else { - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, - get_workspace(src)); - } - - return exit_primitive(execute_primitive<::dnnl::pooling_backward>( - primitive_args, - {{alpha, beta, DNNL_ARG_DIFF_SRC, diff_src_desc, diff_src}})); -} - -inline -sycl::event engine_ext::async_softmax_forward(softmax_algorithm alg, - softmax_mode mode, float alpha, - const memory_desc_ext &src_desc, - void *src, float beta, - const memory_desc_ext &dst_desc, - void *dst) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - - ::dnnl::memory::desc help_src_desc = src_desc.get_desc(); - ::dnnl::memory::desc help_dst_desc = dst_desc.get_desc(); - if (mode == softmax_mode::instance) { - help_src_desc = compress_spatial_dimensions_to_channel(help_src_desc); - help_dst_desc = compress_spatial_dimensions_to_channel(help_dst_desc); - } - enter_primitive(2 * help_dst_desc.get_size()); - - ::dnnl::algorithm softmax_alg = ::dnnl::algorithm::softmax_accurate; - if (alg == softmax_algorithm::log) { - softmax_alg = ::dnnl::algorithm::softmax_log; - } - auto primitive_args = create_primitive_args_or_get<::dnnl::softmax_forward>( - ::dnnl::prop_kind::forward, softmax_alg, help_src_desc, - help_dst_desc, 1); - - insert_arg(primitive_args.second.args, DNNL_ARG_DST, help_dst_desc, dst); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, help_src_desc, src); - - return exit_primitive(execute_primitive<::dnnl::softmax_forward>( - primitive_args, - {{alpha, beta, DNNL_ARG_DST, memory_desc_ext(help_dst_desc), dst}})); -} - -inline -sycl::event engine_ext::async_softmax_backward( - softmax_algorithm alg, softmax_mode mode, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src) { - if (scale_parameter_preprocess({{alpha, beta, diff_src_desc, diff_src}})) { - return sycl::event(); - } - ::dnnl::memory::desc help_diff_src_desc = diff_src_desc.get_desc(); - ::dnnl::memory::desc help_dst_desc = dst_desc.get_desc(); - ::dnnl::memory::desc help_diff_dst_desc = diff_dst_desc.get_desc(); - if (mode == softmax_mode::instance) { - help_diff_src_desc = - compress_spatial_dimensions_to_channel(help_diff_src_desc); - help_dst_desc = compress_spatial_dimensions_to_channel(help_dst_desc); - help_diff_dst_desc = - compress_spatial_dimensions_to_channel(help_diff_dst_desc); - } - enter_primitive(2 * help_diff_src_desc.get_size()); - - ::dnnl::algorithm softmax_alg = ::dnnl::algorithm::softmax_accurate; - if (alg == softmax_algorithm::log) { - softmax_alg = ::dnnl::algorithm::softmax_log; - } - - auto primitive_args = create_primitive_args_or_get<::dnnl::softmax_backward>( - softmax_alg, help_diff_src_desc, help_diff_dst_desc, help_dst_desc, 1, - create_primitive_desc<::dnnl::softmax_forward>( - ::dnnl::prop_kind::forward, softmax_alg, help_diff_src_desc, - help_dst_desc, 1)); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, help_dst_desc, dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, help_diff_dst_desc, - diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, help_diff_src_desc, - diff_src); - - return exit_primitive(execute_primitive<::dnnl::softmax_backward>( - primitive_args, {{alpha, beta, DNNL_ARG_DIFF_SRC, - memory_desc_ext(help_diff_src_desc), diff_src}})); -} - -inline -sycl::event engine_ext::async_lrn_forward(lrn_desc &desc, float alpha, - const memory_desc_ext &src_desc, void *src, - float beta, const memory_desc_ext &dst_desc, - void *dst, ::dnnl::memory *workspace) { - - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - enter_primitive(2 * dst_desc.get_size()); - auto primitive_args = create_primitive_args_or_get<::dnnl::lrn_forward>( - ::dnnl::prop_kind::forward_training, - ::dnnl::algorithm::lrn_across_channels, src_desc.get_desc(), - dst_desc.get_desc(), desc.get_local_size(), desc.get_alpha(), - desc.get_beta(), desc.get_k()); - auto pd = - get_primitive_desc<::dnnl::lrn_forward>(primitive_args.second.primitive); - ::dnnl::memory ws_mem(pd.workspace_desc(), *_eng); - if (workspace) { - *workspace = ws_mem; - } else { - insert_workspace(src, ws_mem); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, ws_mem); - - return exit_primitive(execute_primitive<::dnnl::lrn_forward>( - primitive_args, {{alpha, beta, DNNL_ARG_DST, dst_desc, dst}})); -} - -inline -sycl::event -engine_ext::async_lrn_backward(lrn_desc &desc, float alpha, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &diff_dst_desc, void *diff_dst, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src, - ::dnnl::memory *workspace) { - - if (scale_parameter_preprocess({{alpha, beta, diff_src_desc, diff_src}})) { - return sycl::event(); - } - enter_primitive(2 * diff_src_desc.get_size()); - auto primitive_args = create_primitive_args_or_get<::dnnl::lrn_backward>( - ::dnnl::algorithm::lrn_across_channels, diff_src_desc.get_desc(), - diff_dst_desc.get_desc(), src_desc.get_desc(), desc.get_local_size(), - desc.get_alpha(), desc.get_beta(), desc.get_k(), - create_primitive_desc<::dnnl::lrn_forward>( - ::dnnl::prop_kind::forward_training, - ::dnnl::algorithm::lrn_across_channels, src_desc.get_desc(), - dst_desc.get_desc(), desc.get_local_size(), desc.get_alpha(), - desc.get_beta(), desc.get_k())); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, - diff_dst_desc.get_desc(), diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, - diff_src_desc.get_desc(), diff_src); - - if (workspace) { - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, *workspace); - } else { - insert_arg(primitive_args.second.args, DNNL_ARG_WORKSPACE, - get_workspace(src)); - } - - return exit_primitive(execute_primitive<::dnnl::lrn_backward>( - primitive_args, - {{alpha, beta, DNNL_ARG_DIFF_SRC, diff_src_desc, diff_src}})); -} - -inline -size_t engine_ext::get_batch_normalization_workspace_size( - batch_normalization_ops ops, const memory_desc_ext &src_desc) { - if(ops == batch_normalization_ops::none) { - return 0; - } - return src_desc.get_size(); -} - -inline -sycl::event engine_ext::async_batch_normalization_forward_inference( - batch_normalization_mode mode, float epsilon, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *mean, void *var) { - - return batch_normalization_forward_internal( - true, mode, epsilon, 0.f, alpha, src_desc, src, beta, dst_desc, dst, - scale_bias_mean_var_desc, scale, bias, scale_bias_mean_var_desc, mean, - var, nullptr, nullptr); -} - -inline -sycl::event engine_ext::async_batch_normalization_forward_inference( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *mean, void *var) { - - bool has_post_op = (ops != batch_normalization_ops::none); - sycl::event e; - enter_primitive(src_desc.get_size() + dst_desc.get_size() * 4 + - scale_bias_desc.get_size() * 2 + - mean_var_desc.get_size() * 5); - if (has_post_op) { - void *dst_cache = allocate(dst_desc); - batch_normalization_forward_internal( - true, mode, epsilon, 0.f, 1.f, src_desc, src, 0.f, dst_desc, dst_cache, - scale_bias_desc, scale, bias, mean_var_desc, mean, var, nullptr, - nullptr); - - if (ops == batch_normalization_ops::add_activation) { - async_sum(1.f, summand_desc, summand, 1.f, dst_desc, dst_cache); - } - async_activation_forward(adesc, 1.f, dst_desc, dst_cache, 0.f, dst_desc, - dst_cache); - return exit_primitive( - async_sum(alpha, dst_desc, dst_cache, beta, dst_desc, dst)); - } - return exit_primitive(batch_normalization_forward_internal( - true, mode, epsilon, 0.f, alpha, src_desc, src, beta, dst_desc, dst, - scale_bias_desc, scale, bias, mean_var_desc, mean, var, nullptr, - nullptr)); -} - -inline -sycl::event engine_ext::async_batch_normalization_forward_training( - batch_normalization_mode mode, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *running_mean, void *running_var, void *saved_mean, void *saved_var) { - return batch_normalization_forward_internal( - false, mode, epsilon, factor, alpha, src_desc, src, beta, dst_desc, dst, - scale_bias_mean_var_desc, scale, bias, scale_bias_mean_var_desc, - saved_mean, saved_var, running_mean, running_var); -} - -inline -sycl::event engine_ext::async_batch_normalization_forward_training( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_desc, void *scale, void *bias, - const memory_desc_ext &mean_var_desc, void *running_mean, void *running_var, - void *saved_mean, void *saved_var, size_t workspace_size, - void *workspace) { - enter_primitive(src_desc.get_size() + dst_desc.get_size() * 3 + - mean_var_desc.get_size() * 5 + - scale_bias_desc.get_size() * 2); - bool has_post_op = (ops != batch_normalization_ops::none); - sycl::event e; - if (has_post_op) { - if(workspace_size < dst_desc.get_desc().get_size()) { - throw std::runtime_error("async_batch_normalization_forward_training_ex: " - "no sufficient workspace."); - } - batch_normalization_forward_internal( - false, mode, epsilon, factor, 1.f, src_desc, src, 0.f, dst_desc, - workspace, scale_bias_desc, scale, bias, mean_var_desc, - saved_mean, saved_var, running_mean, running_var); - if (ops == batch_normalization_ops::add_activation) { - async_sum(1.f, summand_desc, summand, 1.f, dst_desc, - workspace); - } - return exit_primitive(async_activation_forward( - adesc, alpha, dst_desc, workspace, beta, dst_desc, dst)); - } - return exit_primitive(batch_normalization_forward_internal( - false, mode, epsilon, factor, alpha, src_desc, src, beta, dst_desc, dst, - scale_bias_desc, scale, bias, mean_var_desc, saved_mean, saved_var, - running_mean, running_var)); -} - -inline -sycl::event engine_ext::async_batch_normalization_forward_training( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float factor, float alpha, - const memory_desc_ext &src_desc, void *src, float beta, - const memory_desc_ext &dst_desc, void *dst, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &scale_bias_mean_var_desc, void *scale, void *bias, - void *running_mean, void *running_var, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace) { - return async_batch_normalization_forward_training( - mode, ops, adesc, epsilon, factor, alpha, src_desc, src, beta, dst_desc, - dst, summand_desc, summand, scale_bias_mean_var_desc, scale, bias, - scale_bias_mean_var_desc, running_mean, running_var, saved_mean, - saved_var, workspace_size, workspace); -} - -inline -sycl::event engine_ext::async_batch_normalization_backward( - batch_normalization_mode mode, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta_data, - const memory_desc_ext &diff_src_desc, void *diff_src, float alpha_param, - const memory_desc_ext &diff_scale_bias_mean_var_desc, void *scale, - float beta_param, void *diff_scale, void *diff_bias, void *saved_mean, - void *saved_var) { - - return batch_normalization_backward_internal( - mode, epsilon, alpha_data, src_desc, src, diff_dst_desc, diff_dst, - beta_data, diff_src_desc, diff_src, alpha_param, - diff_scale_bias_mean_var_desc, scale, nullptr, beta_param, diff_scale, - diff_bias, diff_scale_bias_mean_var_desc, saved_mean, saved_var); -} - -inline -sycl::event engine_ext::async_batch_normalization_backward( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, void *diff_dst, - float beta_data, const memory_desc_ext &diff_src_desc, void *diff_src, - const memory_desc_ext &diff_summand_desc, void *diff_summand, - float alpha_param, const memory_desc_ext &diff_scale_bias_desc, void *scale, - void *bias, float beta_param, void *diff_scale, void *diff_bias, - const memory_desc_ext &mean_var_desc, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace) { - std::vector caches; - ::dnnl::memory::desc real_diff_dst_desc = diff_dst_desc.get_desc(); - void *real_diff_dst = diff_dst; - - if (ops != batch_normalization_ops::none && - workspace_size < dst_desc.get_desc().get_size()) { - throw std::runtime_error("async_batch_normalization_backward_ex: " - "no sufficient workspace."); - } - enter_primitive(diff_scale_bias_desc.get_size() * 8 + - src_desc.get_size() * 3 + diff_dst_desc.get_size() * 5 + - diff_src_desc.get_size() + mean_var_desc.get_size() * 9 + - diff_summand_desc.get_size()); - if (ops == batch_normalization_ops::add_activation) { - void *diff_summand_cache = allocate(diff_summand_desc); - async_activation_backward(adesc, 1.f, dst_desc, dst, diff_dst_desc, diff_dst, - dst_desc, workspace, 0.f, - diff_summand_desc, diff_summand_cache); - async_sum(alpha_data, diff_summand_desc, diff_summand_cache, beta_data, - diff_summand_desc, diff_summand); - real_diff_dst_desc = diff_summand_desc.get_desc(); - real_diff_dst = diff_summand_cache; - } else if (ops == batch_normalization_ops::activation) { - void *diff_dst_cache = allocate(diff_dst_desc); - async_activation_backward(adesc, 1.f, dst_desc, dst, diff_dst_desc, - diff_dst, dst_desc, workspace, - 0.f, diff_dst_desc, diff_dst_cache); - real_diff_dst = diff_dst_cache; - } - - return exit_primitive(batch_normalization_backward_internal( - mode, epsilon, alpha_data, src_desc, src, real_diff_dst_desc, - real_diff_dst, beta_data, diff_src_desc, diff_src, alpha_param, - diff_scale_bias_desc, scale, bias, beta_param, diff_scale, diff_bias, - mean_var_desc, saved_mean, saved_var)); -} - -inline -sycl::event engine_ext::async_batch_normalization_backward( - batch_normalization_mode mode, batch_normalization_ops ops, - activation_desc &adesc, float epsilon, float alpha_data, - const memory_desc_ext &src_desc, void *src, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &diff_dst_desc, void *diff_dst, - float beta_data, const memory_desc_ext &diff_src_desc, void *diff_src, - const memory_desc_ext &diff_summand_desc, void *diff_summand, - float alpha_param, const memory_desc_ext &diff_scale_bias_mean_var_desc, - void *scale, void *bias, float beta_param, void *diff_scale, - void *diff_bias, void *saved_mean, void *saved_var, - size_t workspace_size, void *workspace) { - - return async_batch_normalization_backward( - mode, ops, adesc, epsilon, alpha_data, src_desc, src, dst_desc, dst, - diff_dst_desc, diff_dst, beta_data, diff_src_desc, diff_src, - diff_summand_desc, diff_summand, alpha_param, - diff_scale_bias_mean_var_desc, scale, bias, beta_param, diff_scale, - diff_bias, diff_scale_bias_mean_var_desc, saved_mean, saved_var, - workspace_size, workspace); -} - -inline -sycl::event -engine_ext::async_convolution_forward(convolution_desc &desc, ::dnnl::algorithm alg, - float alpha, const memory_desc_ext &src_desc, - void *src, const memory_desc_ext &weight_desc, - void *weight, float beta, - const memory_desc_ext &dst_desc, void *dst) { - if (scale_parameter_preprocess({{alpha, beta, dst_desc, dst}})) { - return sycl::event(); - } - auto help_weight_desc = - get_group_weight_desc(desc.get_group_count(), weight_desc); - - ::dnnl::primitive_attr attr; - attr.set_fpmath_mode(desc.get_math_mode()); - - auto origin_src_md = src_desc.get_desc(); - auto origin_dst_md = dst_desc.get_desc(); - auto origin_weight_md = help_weight_desc; - auto src_md = transfer_memory_desc_to_format_tag_any(origin_src_md); - auto dst_md = transfer_memory_desc_to_format_tag_any(origin_dst_md); - auto weight_md = transfer_memory_desc_to_format_tag_any(origin_weight_md); - - auto primitive_args = - create_primitive_args_or_get<::dnnl::convolution_forward>( - ::dnnl::prop_kind::forward_training, alg, src_md, weight_md, dst_md, - desc.get_stride(), desc.get_dilate(), desc.get_padding(), - desc.get_padding(), attr); - - auto pd = get_primitive_desc<::dnnl::convolution_forward>( - primitive_args.second.primitive); - auto optimal_src_md = pd.src_desc(); - auto optimal_dst_md = pd.dst_desc(); - auto optimal_weight_md = pd.weights_desc(); - - enter_primitive( - optimal_src_md.get_size() * 3 + optimal_dst_md.get_size() * 5 + - optimal_weight_md.get_size() * 3 + origin_dst_md.get_size() * 2); - - void *optimal_src = src, *optimal_dst = dst, *optimal_weight = weight; - allocate_and_reorder_memory_to_optimal(origin_src_md, src, optimal_src_md, - optimal_src); - allocate_and_reorder_memory_to_optimal(origin_weight_md, weight, - optimal_weight_md, optimal_weight); - - if (beta == 0.f) { - if(origin_dst_md != optimal_dst_md) { - optimal_dst = allocate(optimal_dst_md); - } - } else { - allocate_and_reorder_memory_to_optimal(origin_dst_md, dst, optimal_dst_md, - optimal_dst); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, optimal_src_md, - optimal_src); - insert_arg(primitive_args.second.args, DNNL_ARG_WEIGHTS, optimal_weight_md, - optimal_weight); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, optimal_dst_md, - optimal_dst); - - auto e = execute_primitive<::dnnl::convolution_forward>( - primitive_args, - {{alpha, beta, DNNL_ARG_DST, optimal_dst_md, optimal_dst}}); - - if (origin_dst_md != optimal_dst_md) { - e = async_reorder(1.f, optimal_dst_md, optimal_dst, 0.f, origin_dst_md, - dst); - } - return exit_primitive(e); -} - -inline -sycl::event engine_ext::async_convolution_forward( - convolution_desc &desc, ::dnnl::algorithm alg, activation_desc &adesc, - float alpha_0, const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &weight_desc, void *weight, float alpha_1, - const memory_desc_ext &summand_desc, void *summand, - const memory_desc_ext &bias_desc, void *bias, - const memory_desc_ext &dst_desc, void *dst) { - - int channel_num = bias_desc.get_element_num(); - auto help_weight_desc = - get_group_weight_desc(desc.get_group_count(), weight_desc); - ::dnnl::memory::desc help_bias_desc = {{channel_num}, - bias_desc.get_desc().get_data_type(), - ::dnnl::memory::format_tag::a}; - auto origin_weight_md = help_weight_desc; - auto origin_bias_md = help_bias_desc; - auto origin_src_md = src_desc.get_desc(); - auto origin_dst_md = dst_desc.get_desc(); - auto src_md = transfer_memory_desc_to_format_tag_any(origin_src_md); - auto dst_md = transfer_memory_desc_to_format_tag_any(origin_dst_md); - auto weight_md = transfer_memory_desc_to_format_tag_any(origin_weight_md); - auto bias_md = transfer_memory_desc_to_format_tag_any(origin_bias_md); - - ::dnnl::primitive_attr attr; - attr.set_fpmath_mode(desc.get_math_mode()); - - auto primitive_args = - create_primitive_args_or_get<::dnnl::convolution_forward>( - ::dnnl::prop_kind::forward_training, alg, src_md, weight_md, bias_md, - dst_md, desc.get_stride(), desc.get_dilate(), desc.get_padding(), - desc.get_padding(), attr); - - auto pd = get_primitive_desc<::dnnl::convolution_forward>( - primitive_args.second.primitive); - auto optimal_src_md = pd.src_desc(); - auto optimal_dst_md = pd.dst_desc(); - auto optimal_weight_md = pd.weights_desc(); - auto optimal_bias_md = pd.bias_desc(); - - enter_primitive(optimal_src_md.get_size() + 3 * optimal_weight_md.get_size() + - optimal_bias_md.get_size() + 7 * optimal_dst_md.get_size() + - summand_desc.get_size()); - - void *optimal_src = src, *optimal_dst = dst, *optimal_weight = weight, - *optimal_bias = bias; - allocate_and_reorder_memory_to_optimal(origin_src_md, src, optimal_src_md, - optimal_src); - allocate_and_reorder_memory_to_optimal(origin_weight_md, weight, - optimal_weight_md, optimal_weight); - allocate_and_reorder_memory_to_optimal(origin_bias_md, bias, optimal_bias_md, - optimal_bias); - if (origin_dst_md != optimal_dst_md) { - optimal_dst = allocate(optimal_dst_md); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, optimal_src_md, - optimal_src); - insert_arg(primitive_args.second.args, DNNL_ARG_BIAS, optimal_bias_md, - optimal_bias); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, optimal_dst_md, - optimal_dst); - - void *cache = nullptr; - if (alpha_0 != 1.f) { - cache = allocate(optimal_weight_md); - _q->memcpy(cache, optimal_weight, optimal_weight_md.get_size()); - async_scale(alpha_0, optimal_weight_md, cache); - insert_arg(primitive_args.second.args, DNNL_ARG_WEIGHTS, optimal_weight_md, - cache); - execute_primitive<::dnnl::convolution_forward>( - primitive_args, - {{1.f, 0.f, DNNL_ARG_DST, optimal_dst_md, optimal_dst}}); - } else { - insert_arg(primitive_args.second.args, DNNL_ARG_WEIGHTS, optimal_weight_md, - optimal_weight); - execute_primitive<::dnnl::convolution_forward>( - primitive_args, - {{1.f, 0.f, DNNL_ARG_DST, optimal_dst_md, optimal_dst}}); - } - if (origin_dst_md != optimal_dst_md) { - async_reorder(1.f, optimal_dst_md, optimal_dst, 0.f, origin_dst_md, dst); - } - async_sum(alpha_1, summand_desc, summand, 1.f, dst_desc, dst); - return exit_primitive( - async_activation_forward(adesc, 1.f, dst_desc, dst, 0.f, dst_desc, dst)); -} - -inline -sycl::event engine_ext::async_convolution_backward_data( - convolution_desc &desc, ::dnnl::algorithm alg, float alpha, - const memory_desc_ext &weight_desc, void *weight, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta, - const memory_desc_ext &diff_src_desc, void *diff_src) { - - if (scale_parameter_preprocess({{alpha, beta, diff_dst_desc, diff_dst}})) { - return sycl::event(); - } - - auto help_weight_desc = - get_group_weight_desc(desc.get_group_count(), weight_desc); - - auto origin_weight_md = help_weight_desc; - auto origin_diff_src_md = diff_src_desc.get_desc(); - auto origin_diff_dst_md = diff_dst_desc.get_desc(); - auto diff_src_md = transfer_memory_desc_to_format_tag_any(origin_diff_src_md); - auto diff_dst_md = transfer_memory_desc_to_format_tag_any(origin_diff_dst_md); - auto weight_md = transfer_memory_desc_to_format_tag_any(origin_weight_md); - - ::dnnl::primitive_attr attr; - attr.set_fpmath_mode(desc.get_math_mode()); - - auto forward_primitive = create_primitive_desc<::dnnl::convolution_forward>( - ::dnnl::prop_kind::forward_training, ::dnnl::algorithm::convolution_auto, - diff_src_md, weight_md, diff_dst_md, desc.get_stride(), desc.get_dilate(), - desc.get_padding(), desc.get_padding(), attr); - - auto primitive_args = - create_primitive_args_or_get<::dnnl::convolution_backward_data>( - ::dnnl::algorithm::convolution_auto, diff_src_md, weight_md, - diff_dst_md, desc.get_stride(), desc.get_dilate(), desc.get_padding(), - desc.get_padding(), forward_primitive, attr); - - auto pd = get_primitive_desc<::dnnl::convolution_backward_data>( - primitive_args.second.primitive); - auto optimal_diff_src_md = pd.diff_src_desc(); - auto optimal_diff_dst_md = pd.diff_dst_desc(); - auto optimal_weight_md = pd.weights_desc(); - - enter_primitive(5 * optimal_diff_src_md.get_size() + - optimal_diff_dst_md.get_size() + - optimal_weight_md.get_size()); - - void *optimal_diff_src = diff_src, *optimal_diff_dst = diff_dst, - *optimal_weight = weight; - allocate_and_reorder_memory_to_optimal(origin_diff_dst_md, diff_dst, - optimal_diff_dst_md, optimal_diff_dst); - allocate_and_reorder_memory_to_optimal(origin_weight_md, weight, - optimal_weight_md, optimal_weight); - if (beta == 0.f) { - if (origin_diff_src_md != optimal_diff_src_md) { - optimal_diff_src = allocate(optimal_diff_src_md); - } - } else { - allocate_and_reorder_memory_to_optimal( - origin_diff_src_md, diff_src, optimal_diff_src_md, optimal_diff_src); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, optimal_diff_dst_md, - optimal_diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_WEIGHTS, optimal_weight_md, - optimal_weight); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_SRC, optimal_diff_src_md, - optimal_diff_src); - - auto e = execute_primitive<::dnnl::convolution_backward_data>( - primitive_args, - {{alpha, beta, DNNL_ARG_DIFF_SRC, optimal_diff_src_md, optimal_diff_src}}); - - if (origin_diff_src_md != optimal_diff_src_md) { - e = async_reorder(1.f, optimal_diff_src_md, optimal_diff_src, 0.f, - origin_diff_src_md, diff_src); - } - return exit_primitive(e); -} - -inline -sycl::event engine_ext::async_convolution_backward_weight( - convolution_desc &desc, ::dnnl::algorithm alg, float alpha, - const memory_desc_ext &src_desc, void *src, - const memory_desc_ext &diff_dst_desc, void *diff_dst, float beta, - const memory_desc_ext &diff_weight_desc, void *diff_weight) { - - if (scale_parameter_preprocess( - {{alpha, beta, diff_weight_desc, diff_weight}})) { - return sycl::event(); - } - - auto help_diff_weight_desc = - get_group_weight_desc(desc.get_group_count(), diff_weight_desc); - - ::dnnl::primitive_attr attr; - attr.set_fpmath_mode(desc.get_math_mode()); - - auto origin_diff_weight_md = help_diff_weight_desc; - auto origin_src_md = src_desc.get_desc(); - auto origin_diff_dst_md = diff_dst_desc.get_desc(); - auto src_md = transfer_memory_desc_to_format_tag_any(origin_src_md); - auto diff_dst_md = transfer_memory_desc_to_format_tag_any(origin_diff_dst_md); - auto diff_weight_md = - transfer_memory_desc_to_format_tag_any(origin_diff_weight_md); - - auto forward_primitive = create_primitive_desc<::dnnl::convolution_forward>( - ::dnnl::prop_kind::forward_training, ::dnnl::algorithm::convolution_auto, - src_md, diff_weight_md, diff_dst_md, desc.get_stride(), desc.get_dilate(), - desc.get_padding(), desc.get_padding(), attr); - - auto primitive_args = - create_primitive_args_or_get<::dnnl::convolution_backward_weights>( - ::dnnl::algorithm::convolution_auto, src_md, diff_weight_md, - diff_dst_md, desc.get_stride(), desc.get_dilate(), desc.get_padding(), - desc.get_padding(), forward_primitive, attr); - - auto pd = get_primitive_desc<::dnnl::convolution_backward_weights>( - primitive_args.second.primitive); - auto optimal_src_md = pd.src_desc(); - auto optimal_diff_dst_md = pd.diff_dst_desc(); - auto optimal_diff_weight_md = pd.diff_weights_desc(); - - enter_primitive(optimal_diff_weight_md.get_size() * 5 + - optimal_diff_dst_md.get_size() + optimal_src_md.get_size()); - - void *optimal_src = src, *optimal_diff_dst = diff_dst, - *optimal_diff_weight = diff_weight; - allocate_and_reorder_memory_to_optimal(origin_diff_dst_md, diff_dst, - optimal_diff_dst_md, optimal_diff_dst); - allocate_and_reorder_memory_to_optimal(origin_src_md, src, optimal_src_md, - optimal_src); - if (beta == 0.f) { - if (origin_diff_weight_md != optimal_diff_weight_md) { - optimal_diff_weight = allocate(optimal_diff_weight_md); - } - } else { - allocate_and_reorder_memory_to_optimal(origin_diff_weight_md, diff_weight, - optimal_diff_weight_md, - optimal_diff_weight); - } - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC, optimal_src_md, - optimal_src); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_DST, optimal_diff_dst_md, - optimal_diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_DIFF_WEIGHTS, - optimal_diff_weight_md, optimal_diff_weight); - - auto e = execute_primitive<::dnnl::convolution_backward_weights>( - primitive_args, {{alpha, beta, DNNL_ARG_DIFF_WEIGHTS, - optimal_diff_weight_md, optimal_diff_weight}}); - - if (origin_diff_weight_md != optimal_diff_weight_md) { - e = async_reorder(1.f, optimal_diff_weight_md, optimal_diff_weight, 0.f, - origin_diff_weight_md, diff_weight); - } - return exit_primitive(e); -} - -inline -sycl::event engine_ext::async_convolution_backward_bias( - float alpha, const memory_desc_ext &diff_dst_desc, void *diff_dst, - float beta, const memory_desc_ext &diff_bias_desc, void *diff_bias) { - return async_reduction(reduction_op::sum, alpha, diff_dst_desc, diff_dst, beta, - diff_bias_desc, diff_bias); -} - -inline -void engine_ext::rnn_get_weight_space_size(const rnn_desc &desc, - size_t *weight_space_size) { - *weight_space_size = 0; - rnn_forward_internal(desc, ::dnnl::prop_kind::forward_inference, - memory_desc_ext(), nullptr, memory_desc_ext(), nullptr, - memory_desc_ext(), nullptr, nullptr, memory_desc_ext(), - nullptr, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, true, - weight_space_size, nullptr, nullptr); - return; -} - -inline -void engine_ext::rnn_get_scratchpad_workspace_size( - const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, size_t *scratchpad_size, - size_t *workspace_size) { - *workspace_size = 0; - *scratchpad_size = 0; - rnn_forward_internal(desc, kind, src_desc, nullptr, memory_desc_ext(), - nullptr, memory_desc_ext(), nullptr, nullptr, - memory_desc_ext(), nullptr, nullptr, 0, nullptr, 0, - nullptr, 0, nullptr, true, nullptr, workspace_size, - scratchpad_size); - return; -} - -inline -sycl::event engine_ext::async_rnn_forward( - const rnn_desc &desc, ::dnnl::prop_kind kind, - const memory_desc_ext &src_desc, void *src, const memory_desc_ext &dst_desc, - void *dst, const memory_desc_ext &iter_desc, void *src_iter, void *dst_iter, - const memory_desc_ext &iter_c_desc, void *src_iter_c, void *dst_iter_c, - size_t weight_size, void *weight, size_t scratchpad_size, void *scratchpad, - size_t workspace_size, void *workspace) { - - return rnn_forward_internal( - desc, kind, src_desc, src, dst_desc, dst, iter_desc, src_iter, dst_iter, - iter_c_desc, src_iter_c, dst_iter_c, weight_size, weight, workspace_size, - workspace, scratchpad_size, scratchpad, false, nullptr, nullptr, - nullptr); -} - -inline -sycl::event engine_ext::async_rnn_backward( - const rnn_desc &desc, const memory_desc_ext &dst_desc, void *dst, - void *diff_dst, const memory_desc_ext &src_desc, void *src, void *diff_src, - const memory_desc_ext &iter_desc, void *src_iter, void *diff_dst_iter, - void *diff_src_iter, const memory_desc_ext &iter_c_desc, void *src_iter_c, - void *diff_dst_iter_c, void *diff_src_iter_c, size_t weight_size, - void *weight, void *diff_weight, size_t scratchpad_size, void *scratchpad, - size_t workspace_size, void *workspace) { - ::dnnl::memory::data_type src_dt; - ::dnnl::memory::format_tag src_format_tag; - rnn_mode mode; - rnn_memory_format_tag format_tag; - rnn_bias_mode bias_mode; - rnn_direction direction; - dpct::library_data_t dt; - int direction_num = 1, input_size = 0, hidden_size = 0, projection_size = 0, - layer_size = 0, gate_num = 1, output_size = 0, data_type_size = 0, - seq_length = 1, batch_size = 1; - void *last_layer_cache = nullptr; - void *hidden_layer_cache = nullptr; - sycl::event e; - enter_primitive(src_desc.get_size() * 2); - std::vector offset(9, 0); - std::vector data = { - src, - dst, - (uint8_t *)src_iter + iter_desc.get_size(), - nullptr, - (uint8_t *)src_iter_c + iter_c_desc.get_size(), - nullptr, - (uint8_t *)weight + weight_size, - (uint8_t *)workspace + workspace_size, - diff_src, - diff_dst, - (uint8_t *)diff_src_iter + iter_desc.get_size(), - (uint8_t *)diff_dst_iter + iter_desc.get_size(), - (uint8_t *)diff_src_iter_c + iter_c_desc.get_size(), - (uint8_t *)diff_dst_iter_c + iter_c_desc.get_size(), - (uint8_t *)diff_weight + weight_size, - scratchpad}; - - desc.get(&mode, &bias_mode, &direction, &dt, &input_size, &hidden_size, - &projection_size, &layer_size); - - get_rnn_configuration(src_desc.get_desc(), direction, mode, dt, hidden_size, - &src_dt, &src_format_tag, &projection_size, - &output_size, &seq_length, &batch_size, &direction_num, - &gate_num); - - if (direction == rnn_direction::bidirectional) { - if (layer_size > 1) { - last_layer_cache = allocate(src_desc); - hidden_layer_cache = allocate(src_desc); - data[8] = last_layer_cache; - } - e = execute_rnn_backward_primitive( - mode, ::dnnl::rnn_direction::bidirectional_concat, bias_mode, src_dt, - src_format_tag, seq_length, batch_size, output_size, 2 * output_size, 1, - direction_num, hidden_size, gate_num, projection_size, data, offset, 1); - if (layer_size > 1) { - data[8] = hidden_layer_cache; - data[9] = last_layer_cache; - e = execute_rnn_backward_primitive( - mode, ::dnnl::rnn_direction::bidirectional_sum, bias_mode, src_dt, - src_format_tag, seq_length, batch_size, output_size, output_size, 1, - direction_num, hidden_size, gate_num, projection_size, data, offset, - layer_size - 1); - _q->memcpy(diff_src, - ((layer_size - 1) % 2 == 0) ? last_layer_cache - : hidden_layer_cache, - src_desc.get_size()); - } - } else { - e = execute_rnn_backward_primitive( - mode, ::dnnl::rnn_direction::unidirectional_left2right, bias_mode, - src_dt, src_format_tag, seq_length, batch_size, output_size, - output_size, layer_size, direction_num, hidden_size, gate_num, - projection_size, data, offset, 1); - } - - return exit_primitive(e); -} - -inline -size_t engine_ext::get_dropout_state_size(){ -#ifndef __INTEL_MKL__ - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) " - "Interfaces Project does not support this API."); -#else - auto r = get_internal_resource(_q); - if(r->random_engine_state_size == -1){ - auto rand_engine = rng_engine_t(*_q, 0); - r->random_engine_state_size = - oneapi::mkl::rng::get_state_size(rand_engine); - } - return r->random_engine_state_size; -#endif -} - -inline size_t -engine_ext::get_dropout_workspace_size(const memory_desc_ext &src_desc) { - return src_desc.get_size(); -} - -inline -sycl::event engine_ext::async_dropout_forward(dropout_desc &desc, - const memory_desc_ext &src_desc, - void *src, - const memory_desc_ext &dst_desc, - void *dst, void *workspace, - size_t workspace_size) { - if (workspace_size < src_desc.get_size()) { - throw std::runtime_error("async_dropout_forward: no sufficient workspace."); - } - enter_primitive(src_desc.get_size() * 2 + dst_desc.get_size() * 2); - float p = desc.get_probability(); - if (p == 1.f) { - return _q->memset(dst, 0, dst_desc.get_size()); - } else if (p == 0.f) { - return async_reorder(1.f, src_desc, src, 0.f, dst_desc, dst); - } - - float scale_factor = 1.f / (1.f - p); - void *cache = workspace; - - memory_desc_ext rng_data_desc( - ::dnnl::memory::desc(src_desc.get_dims(), ::dnnl::memory::data_type::s32, - src_desc.get_strides())); - if (src_desc.get_desc().get_data_type() != ::dnnl::memory::data_type::s32) { - cache = allocate(rng_data_desc); - } - - desc.generate(_q, get_dropout_state_size(), rng_data_desc.get_element_num(), - (std::int32_t *)cache); - - if (cache == workspace) { - async_scale(scale_factor, src_desc, workspace); - } else { - async_reorder(scale_factor, rng_data_desc, cache, 0.f, src_desc, workspace); - } - - auto primitive_args = create_primitive_args_or_get<::dnnl::binary>( - ::dnnl::algorithm::binary_mul, src_desc.get_desc(), src_desc.get_desc(), - dst_desc.get_desc()); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_0, src_desc.get_desc(), - src); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_1, src_desc.get_desc(), - workspace); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, dst_desc.get_desc(), - dst); - - return exit_primitive(execute_primitive<::dnnl::binary>(primitive_args)); -} - -inline -sycl::event engine_ext::async_dropout_backward( - dropout_desc &desc, const memory_desc_ext &diff_dst_desc, - void *diff_dst, const memory_desc_ext &diff_src_desc, void *diff_src, - void *workspace, size_t workspace_size) { - enter_primitive(2 * diff_src_desc.get_size()); - float p = desc.get_probability(); - if (p == 1.f) { - return _q->memset(diff_src, 0, diff_src_desc.get_size()); - } else if (p == 0.f) { - return async_reorder(1.f, diff_dst_desc, diff_dst, 0.f, diff_src_desc, - diff_src); - } - - auto primitive_args = create_primitive_args_or_get<::dnnl::binary>( - ::dnnl::algorithm::binary_mul, diff_dst_desc.get_desc(), - diff_dst_desc.get_desc(), diff_src_desc.get_desc()); - - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_0, - diff_dst_desc.get_desc(), diff_dst); - insert_arg(primitive_args.second.args, DNNL_ARG_SRC_1, - diff_dst_desc.get_desc(), workspace); - insert_arg(primitive_args.second.args, DNNL_ARG_DST, diff_src_desc.get_desc(), - diff_src); - - return exit_primitive(execute_primitive<::dnnl::binary>(primitive_args)); -} -} // namespace dnnl -} // namespace dpct - -#endif // __DPCT_DNNL_UTILS_HPP__ diff --git a/dpct/dpct.hpp b/dpct/dpct.hpp deleted file mode 100644 index 8cc312f0e..000000000 --- a/dpct/dpct.hpp +++ /dev/null @@ -1,62 +0,0 @@ -//==---- dpct.hpp ---------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_HPP__ -#define __DPCT_HPP__ - -#include -#include -#include -#include - -template class dpct_kernel_name; -template class dpct_kernel_scalar; - -#include "atomic.hpp" -#include "device.hpp" -#include "image.hpp" -#include "kernel.hpp" -#include "math.hpp" -#include "memory.hpp" -#include "util.hpp" - -#if defined(_MSC_VER) -#define __dpct_align__(n) __declspec(align(n)) -#define __dpct_inline__ __forceinline -#else -#define __dpct_align__(n) __attribute__((aligned(n))) -#define __dpct_inline__ __inline__ __attribute__((always_inline)) -#endif - -#if defined(_MSC_VER) -#define __dpct_noinline__ __declspec(noinline) -#else -#define __dpct_noinline__ __attribute__((noinline)) -#endif - -#define DPCT_COMPATIBILITY_TEMP (900) - -namespace dpct{ -enum error_code { success = 0, default_error = 999 }; -} - -#define DPCT_CHECK_ERROR(expr) \ - [&]() { \ - try { \ - expr; \ - return dpct::success; \ - } catch (std::exception const &e) { \ - std::cerr << e.what() << std::endl; \ - return dpct::default_error; \ - } \ - }() - -#define DPCT_PI_F (3.14159274101257f) -#define DPCT_PI (3.141592653589793115998) - -#endif // __DPCT_HPP__ diff --git a/dpct/dpl_extras/algorithm.h b/dpct/dpl_extras/algorithm.h deleted file mode 100644 index 7c98b7a22..000000000 --- a/dpct/dpl_extras/algorithm.h +++ /dev/null @@ -1,2419 +0,0 @@ -//==---- algorithm.h ------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_ALGORITHM_H__ -#define __DPCT_ALGORITHM_H__ - -#include -#include -#include - -#include "functional.h" -#include "iterators.h" -#include "vector.h" - -namespace dpct { - -template -void replace_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p, - const T &new_value) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - std::transform( - std::forward(policy), first, last, mask, first, - internal::replace_if_fun::value_type, - Pred>(p, new_value)); -} - -template -Iter3 replace_copy_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 result, Pred p, const T &new_value) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - return std::transform( - std::forward(policy), first, last, mask, result, - internal::replace_if_fun::value_type, - Pred>(p, new_value)); -} - -template -internal::enable_if_hetero_execution_policy -remove_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using oneapi::dpl::make_zip_iterator; - using policy_type = typename std::decay::type; - using internal::__buffer; - using ValueType = typename std::iterator_traits::value_type; - - __buffer _tmp(std::distance(first, last)); - - auto end = std::copy_if( - policy, make_zip_iterator(first, mask), - make_zip_iterator(last, mask + std::distance(first, last)), - make_zip_iterator(_tmp.get(), oneapi::dpl::discard_iterator()), - internal::negate_predicate_key_fun(p)); - return std::copy(std::forward(policy), _tmp.get(), - std::get<0>(end.base()), first); -} - -template -typename std::enable_if::type>::value, - Iter1>::type -remove_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using oneapi::dpl::make_zip_iterator; - using policy_type = typename std::decay::type; - using ValueType = typename std::iterator_traits::value_type; - - std::vector _tmp(std::distance(first, last)); - - auto end = std::copy_if( - policy, make_zip_iterator(first, mask), - make_zip_iterator(last, mask + std::distance(first, last)), - make_zip_iterator(_tmp.begin(), oneapi::dpl::discard_iterator()), - internal::negate_predicate_key_fun(p)); - return std::copy(policy, _tmp.begin(), std::get<0>(end.base()), first); -} - -template -Iter3 remove_copy_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 result, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using oneapi::dpl::make_zip_iterator; - auto ret_val = std::remove_copy_if( - std::forward(policy), make_zip_iterator(first, mask), - make_zip_iterator(last, mask + std::distance(first, last)), - make_zip_iterator(result, oneapi::dpl::discard_iterator()), - internal::predicate_key_fun(p)); - return std::get<0>(ret_val.base()); -} - -template -std::pair unique(Policy &&policy, Iter1 keys_first, - Iter1 keys_last, Iter2 values_first, - BinaryPred binary_pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::unique( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first, values_first), - oneapi::dpl::make_zip_iterator( - keys_last, values_first + std::distance(keys_first, keys_last)), - internal::compare_key_fun(binary_pred)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_first, values_first), ret_val); - return std::make_pair(keys_first + n1, values_first + n1); -} - -template -std::pair unique(Policy &&policy, Iter1 keys_first, - Iter1 keys_last, Iter2 values_first) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using T = typename std::iterator_traits::value_type; - return unique(std::forward(policy), keys_first, keys_last, - values_first, std::equal_to()); -} - -template -std::pair unique_copy(Policy &&policy, Iter1 keys_first, - Iter1 keys_last, Iter2 values_first, - Iter3 keys_result, Iter4 values_result, - BinaryPred binary_pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::unique_copy( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first, values_first), - oneapi::dpl::make_zip_iterator( - keys_last, values_first + std::distance(keys_first, keys_last)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::unique_fun(binary_pred)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair unique_copy(Policy &&policy, Iter1 keys_first, - Iter1 keys_last, Iter2 values_first, - Iter3 keys_result, Iter4 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using T = typename std::iterator_traits::value_type; - auto comp = std::equal_to(); - return unique_copy(std::forward(policy), keys_first, keys_last, - values_first, keys_result, values_result, comp); -} - -template -Iter partition_point(Policy &&policy, Iter first, Iter last, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - if (std::is_partitioned(policy, first, last, p)) - return std::find_if_not(std::forward(policy), first, last, p); - else - return first; -} - -template -Iter3 copy_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 result, Pred pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::copy_if( - std::forward(policy), oneapi::dpl::make_zip_iterator(first, mask), - oneapi::dpl::make_zip_iterator(last, mask + std::distance(first, last)), - oneapi::dpl::make_zip_iterator(result, oneapi::dpl::discard_iterator()), - internal::predicate_key_fun(pred)); - return std::get<0>(ret_val.base()); -} - -template -Iter2 transform_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 result, - UnaryOperation unary_op, Pred pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using T = typename std::iterator_traits::value_type; - const auto n = std::distance(first, last); - std::for_each( - std::forward(policy), - oneapi::dpl::make_zip_iterator(first, result), - oneapi::dpl::make_zip_iterator(first, result) + n, - internal::transform_if_fun(pred, unary_op)); - return result + n; -} - -template -Iter3 transform_if(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 result, UnaryOperation unary_op, Pred pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using T = typename std::iterator_traits::value_type; - using Ref1 = typename std::iterator_traits::reference; - using Ref2 = typename std::iterator_traits::reference; - const auto n = std::distance(first, last); - std::for_each( - std::forward(policy), - oneapi::dpl::make_zip_iterator(first, mask, result), - oneapi::dpl::make_zip_iterator(first, mask, result) + n, - internal::transform_if_unary_zip_mask_fun( - pred, unary_op)); - return result + n; -} - -template -Iter4 transform_if(Policy &&policy, Iter1 first1, Iter1 last1, Iter2 first2, - Iter3 mask, Iter4 result, BinaryOperation binary_op, - Pred pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - const auto n = std::distance(first1, last1); - using ZipIterator = - typename oneapi::dpl::zip_iterator; - using T = typename std::iterator_traits::value_type; - std::for_each( - std::forward(policy), - oneapi::dpl::make_zip_iterator(first1, first2, mask, result), - oneapi::dpl::make_zip_iterator(last1, first2 + n, mask + n, result + n), - internal::transform_if_zip_mask_fun(pred, - binary_op)); - return result + n; -} - -template -void scatter(Policy &&policy, InputIter1 first, InputIter1 last, InputIter2 map, - OutputIter result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - oneapi::dpl::copy(policy, first, last, - oneapi::dpl::make_permutation_iterator(result, map)); -} - -template -OutputIter gather(Policy &&policy, InputIter1 map_first, InputIter1 map_last, - InputIter2 input_first, OutputIter result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto perm_begin = - oneapi::dpl::make_permutation_iterator(input_first, map_first); - const int n = ::std::distance(map_first, map_last); - - return oneapi::dpl::copy(policy, perm_begin, perm_begin + n, result); -} - -template -void scatter_if(Policy &&policy, InputIter1 first, InputIter1 last, - InputIter2 map, InputIter3 mask, OutputIter result, - Predicate pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - transform_if( - ::std::forward(policy), first, last, mask, - oneapi::dpl::make_permutation_iterator(result, map), - [=](auto &&v) { return v; }, [=](auto &&m) { return pred(m); }); -} - -template -void scatter_if(Policy &&policy, InputIter1 first, InputIter1 last, - InputIter2 map, InputIter3 mask, OutputIter result) { - scatter_if(::std::forward(policy), first, last, map, mask, result, - internal::no_op_fun()); -} - -template -OutputIter gather_if(Policy &&policy, InputIter1 map_first, InputIter1 map_last, - InputIter2 mask, InputIter3 input_first, OutputIter result, - Predicate pred) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same< - typename std::iterator_traits::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto perm_begin = - oneapi::dpl::make_permutation_iterator(input_first, map_first); - const int n = std::distance(map_first, map_last); - - return transform_if( - ::std::forward(policy), perm_begin, perm_begin + n, mask, result, - [=](auto &&v) { return v; }, [=](auto &&m) { return pred(m); }); -} - -template -OutputIter gather_if(Policy &&policy, InputIter1 map_first, InputIter1 map_last, - InputIter2 mask, InputIter3 input_first, - OutputIter result) { - return gather_if(::std::forward(policy), map_first, map_last, mask, - input_first, result, internal::no_op_fun()); -} - -template -std::pair -merge(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, Iter2 keys_first2, - Iter2 keys_last2, Iter3 values_first1, Iter4 values_first2, - Iter5 keys_result, Iter6 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto n1 = std::distance(keys_first1, keys_last1); - auto n2 = std::distance(keys_first2, keys_last2); - std::merge(std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator(keys_last1, values_first1 + n1), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator(keys_last2, values_first2 + n2), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun<>()); - return std::make_pair(keys_result + n1 + n2, values_result + n1 + n2); -} - -template -std::pair -merge(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, Iter2 keys_first2, - Iter2 keys_last2, Iter3 values_first1, Iter4 values_first2, - Iter5 keys_result, Iter6 values_result, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto n1 = std::distance(keys_first1, keys_last1); - auto n2 = std::distance(keys_first2, keys_last2); - std::merge(std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator(keys_last1, values_first1 + n1), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator(keys_last2, values_first2 + n2), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun(comp)); - return std::make_pair(keys_result + n1 + n2, values_result + n1 + n2); -} - -template -void iota(Policy &&policy, Iter first, Iter last, T init, T step) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using DiffSize = typename std::iterator_traits::difference_type; - std::transform( - std::forward(policy), oneapi::dpl::counting_iterator(0), - oneapi::dpl::counting_iterator(std::distance(first, last)), - first, internal::sequence_fun(init, step)); -} - -template -void iota(Policy &&policy, Iter first, Iter last, T init) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - iota(std::forward(policy), first, last, init, T(1)); -} - -template -void iota(Policy &&policy, Iter first, Iter last) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using DiffSize = typename std::iterator_traits::difference_type; - iota(std::forward(policy), first, last, DiffSize(0), DiffSize(1)); -} - -template -void sort(Policy &&policy, Iter1 keys_first, Iter1 keys_last, - Iter2 values_first, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto first = oneapi::dpl::make_zip_iterator(keys_first, values_first); - auto last = first + std::distance(keys_first, keys_last); - std::sort(std::forward(policy), first, last, - internal::compare_key_fun(comp)); -} - -template -void sort(Policy &&policy, Iter1 keys_first, Iter1 keys_last, - Iter2 values_first) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - sort(std::forward(policy), keys_first, keys_last, values_first, - internal::__less()); -} - -template -void stable_sort(Policy &&policy, Iter1 keys_first, Iter1 keys_last, - Iter2 values_first, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - std::stable_sort( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first, values_first), - oneapi::dpl::make_zip_iterator( - keys_last, values_first + std::distance(keys_first, keys_last)), - internal::compare_key_fun(comp)); -} - -template -void stable_sort(Policy &&policy, Iter1 keys_first, Iter1 keys_last, - Iter2 values_first) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - stable_sort(std::forward(policy), keys_first, keys_last, values_first, - internal::__less()); -} - -template -void for_each_index(Policy &&policy, Iter first, Iter last, Operator unary_op) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - using DiffSize = typename std::iterator_traits::difference_type; - std::transform( - std::forward(policy), oneapi::dpl::counting_iterator(0), - oneapi::dpl::counting_iterator(std::distance(first, last)), - first, unary_op); -} - -template -std::pair -set_intersection(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, Iter3 values_first1, - Iter4 keys_result, Iter5 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_intersection( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, - oneapi::dpl::discard_iterator()), - oneapi::dpl::make_zip_iterator(keys_last2, - oneapi::dpl::discard_iterator()), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun<>()); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair -set_intersection(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, Iter3 values_first1, - Iter4 keys_result, Iter5 values_result, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_intersection( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, - oneapi::dpl::discard_iterator()), - oneapi::dpl::make_zip_iterator(keys_last2, - oneapi::dpl::discard_iterator()), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun(comp)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair -set_symmetric_difference(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, - Iter3 values_first1, Iter4 values_first2, - Iter5 keys_result, Iter6 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_symmetric_difference( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun<>()); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair -set_symmetric_difference(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, - Iter3 values_first1, Iter4 values_first2, - Iter5 keys_result, Iter6 values_result, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_symmetric_difference( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun(comp)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair -set_difference(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, Iter3 values_first1, - Iter4 values_first2, Iter5 keys_result, Iter6 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_difference( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun<>()); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -std::pair set_difference(Policy &&policy, Iter1 keys_first1, - Iter1 keys_last1, Iter2 keys_first2, - Iter2 keys_last2, Iter3 values_first1, - Iter4 values_first2, Iter5 keys_result, - Iter6 values_result, Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_difference( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun(comp)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -internal::enable_if_execution_policy> -set_union(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, Iter3 values_first1, - Iter4 values_first2, Iter5 keys_result, Iter6 values_result) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_union( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun<>()); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -internal::enable_if_execution_policy> -set_union(Policy &&policy, Iter1 keys_first1, Iter1 keys_last1, - Iter2 keys_first2, Iter2 keys_last2, Iter3 values_first1, - Iter4 values_first2, Iter5 keys_result, Iter6 values_result, - Comp comp) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::set_union( - std::forward(policy), - oneapi::dpl::make_zip_iterator(keys_first1, values_first1), - oneapi::dpl::make_zip_iterator( - keys_last1, values_first1 + std::distance(keys_first1, keys_last1)), - oneapi::dpl::make_zip_iterator(keys_first2, values_first2), - oneapi::dpl::make_zip_iterator( - keys_last2, values_first2 + std::distance(keys_first2, keys_last2)), - oneapi::dpl::make_zip_iterator(keys_result, values_result), - internal::compare_key_fun(comp)); - auto n1 = std::distance( - oneapi::dpl::make_zip_iterator(keys_result, values_result), ret_val); - return std::make_pair(keys_result + n1, values_result + n1); -} - -template -internal::enable_if_execution_policy> -stable_partition_copy(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 out_true, Iter4 out_false, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - auto ret_val = std::partition_copy( - std::forward(policy), oneapi::dpl::make_zip_iterator(first, mask), - oneapi::dpl::make_zip_iterator(last, mask + std::distance(first, last)), - oneapi::dpl::make_zip_iterator(out_true, oneapi::dpl::discard_iterator()), - oneapi::dpl::make_zip_iterator(out_false, - oneapi::dpl::discard_iterator()), - internal::predicate_key_fun(p)); - return std::make_pair(std::get<0>(ret_val.first.base()), - std::get<0>(ret_val.second.base())); -} - -template -internal::enable_if_execution_policy> -stable_partition_copy(Policy &&policy, Iter1 first, Iter1 last, Iter3 out_true, - Iter4 out_false, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - return std::partition_copy(std::forward(policy), first, last, - out_true, out_false, p); -} - -template -internal::enable_if_execution_policy> -partition_copy(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, - Iter3 out_true, Iter4 out_false, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - return stable_partition_copy(std::forward(policy), first, last, mask, - out_true, out_false, p); -} - -template -internal::enable_if_hetero_execution_policy -stable_partition(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - typedef typename std::decay::type policy_type; - internal::__buffer::value_type> _tmp( - std::distance(first, last)); - - std::copy(policy, mask, mask + std::distance(first, last), _tmp.get()); - - auto ret_val = - std::stable_partition(std::forward(policy), - oneapi::dpl::make_zip_iterator(first, _tmp.get()), - oneapi::dpl::make_zip_iterator( - last, _tmp.get() + std::distance(first, last)), - internal::predicate_key_fun(p)); - return std::get<0>(ret_val.base()); -} - -template -typename std::enable_if::type>::value, - Iter1>::type -stable_partition(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - typedef typename std::decay::type policy_type; - std::vector::value_type> _tmp( - std::distance(first, last)); - - std::copy(policy, mask, mask + std::distance(first, last), _tmp.begin()); - - auto ret_val = std::stable_partition( - std::forward(policy), - oneapi::dpl::make_zip_iterator(first, _tmp.begin()), - oneapi::dpl::make_zip_iterator(last, - _tmp.begin() + std::distance(first, last)), - internal::predicate_key_fun(p)); - return std::get<0>(ret_val.base()); -} - -template -internal::enable_if_execution_policy -partition(Policy &&policy, Iter1 first, Iter1 last, Iter2 mask, Pred p) { - static_assert( - std::is_same::iterator_category, - std::random_access_iterator_tag>::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - "Iterators passed to algorithms must be random-access iterators."); - return stable_partition(std::forward(policy), first, last, mask, p); -} - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value> -sort_pairs(Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, bool descending = false, - int begin_bit = 0, - int end_bit = - sizeof(typename ::std::iterator_traits::value_type) * 8); - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value> -sort_keys(Policy &&policy, Iter1 keys_in, Iter2 keys_out, ::std::int64_t n, - bool descending = false, int begin_bit = 0, - int end_bit = - sizeof(typename ::std::iterator_traits::value_type) * 8); - -namespace internal { - -// Transforms key to a specific bit range and sorts the transformed key -template -inline void transform_and_sort(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - ::std::int64_t n, bool descending, int begin_bit, - int end_bit) { - using key_t_value_t = typename std::iterator_traits::value_type; - auto trans_key = - translate_key(begin_bit, end_bit); - - // Use of the comparison operator that is not simply std::greater() or - // std::less() will result in - // not using radix sort which will cost some performance. However, this is - // necessary to provide the transformation of the key to the bitrange - // desired. - auto partial_sort_with_comp = [&](const auto &comp) { - return oneapi::dpl::partial_sort_copy( - std::forward(policy), keys_in, keys_in + n, keys_out, - keys_out + n, [=](const auto a, const auto b) { - return comp(trans_key(a), trans_key(b)); - }); - }; - if (descending) - partial_sort_with_comp(::std::greater()); - else - partial_sort_with_comp(::std::less()); -} - -template -inline void sort_only(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - ::std::int64_t n, bool descending) { - using key_t_value_t = typename ::std::iterator_traits::value_type; - - if constexpr (::std::is_floating_point::value) { - if (descending) { - // Comparison operator that is not std::greater() ensures stability of - // -0.0 and 0.0 - // at the cost of some performance because radix sort will not be used. - auto comp_descending = [=](const auto a, const auto b) { return a > b; }; - - oneapi::dpl::partial_sort_copy(::std::forward(policy), keys_in, - keys_in + n, keys_out, keys_out + n, - comp_descending); - } else { - // Comparison operator that is not std::less() ensures stability of -0.0 - // and 0.0 - // at the cost of some performance because radix sort will not be used. - auto comp_ascending = [=](const auto a, const auto b) { return a < b; }; - - oneapi::dpl::partial_sort_copy(::std::forward(policy), keys_in, - keys_in + n, keys_out, keys_out + n, - comp_ascending); - } - } else { - if (descending) { - oneapi::dpl::partial_sort_copy(::std::forward(policy), keys_in, - keys_in + n, keys_out, keys_out + n, - ::std::greater()); - } else { - - oneapi::dpl::partial_sort_copy(::std::forward(policy), keys_in, - keys_in + n, keys_out, keys_out + n); - } - } -} - -// Transforms key from a pair to a specific bit range and sorts the pairs by the -// transformed key -template -inline void -transform_and_sort_pairs(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - Iter3 values_in, Iter4 values_out, ::std::int64_t n, - bool descending, int begin_bit, int end_bit) { - using key_t_value_t = typename std::iterator_traits::value_type; - auto zip_input = oneapi::dpl::zip_iterator(keys_in, values_in); - auto zip_output = oneapi::dpl::zip_iterator(keys_out, values_out); - auto trans_key = - translate_key(begin_bit, end_bit); - - // Use of the comparison operator that is not simply std::greater() or - // std::less() will result in - // not using radix sort which will cost some performance. However, this is - // necessary to provide the transformation of the key to the bitrange desired - // and also to select the key from the zipped pair. - auto load_val = [=](const auto a) { return trans_key(std::get<0>(a)); }; - - auto partial_sort_with_comp = [&](const auto &comp) { - return oneapi::dpl::partial_sort_copy( - std::forward(policy), zip_input, zip_input + n, zip_output, - zip_output + n, [=](const auto a, const auto b) { - return comp(load_val(a), load_val(b)); - }); - }; - if (descending) - partial_sort_with_comp(::std::greater()); - else - partial_sort_with_comp(::std::less()); -} - -template -inline void sort_only_pairs(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - Iter3 values_in, Iter4 values_out, ::std::int64_t n, - bool descending) { - using key_t_value_t = typename ::std::iterator_traits::value_type; - auto zip_input = oneapi::dpl::zip_iterator(keys_in, values_in); - auto zip_output = oneapi::dpl::zip_iterator(keys_out, values_out); - - // Use of the comparison operator that is not simply std::greater() or - // std::less() will result in - // not using radix sort which will cost some performance. However, this is - // necessary to select the key from the zipped pair. - auto load_val = [=](const auto a) { return std::get<0>(a); }; - - auto partial_sort_with_comp = [&](const auto &comp) { - return oneapi::dpl::partial_sort_copy( - std::forward(policy), zip_input, zip_input + n, zip_output, - zip_output + n, [=](const auto a, const auto b) { - return comp(load_val(a), load_val(b)); - }); - }; - if (descending) - partial_sort_with_comp(::std::greater()); - else - partial_sort_with_comp(::std::less()); -} - -// overload for Iter2 != std::nullptr_t -template -typename ::std::enable_if::value>::type -sort_pairs_impl(Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, bool descending, - int begin_bit, int end_bit) { - using key_t_value_t = typename ::std::iterator_traits::value_type; - - int clipped_begin_bit = ::std::max(begin_bit, 0); - int clipped_end_bit = - ::std::min((::std::uint64_t)end_bit, sizeof(key_t_value_t) * 8); - int num_bytes = (clipped_end_bit - clipped_begin_bit - 1) / 8 + 1; - - auto transform_and_sort_pairs_f = [&](auto x) { - using T = typename ::std::decay_t; - internal::transform_and_sort_pairs( - ::std::forward(policy), keys_in, keys_out, values_in, - values_out, n, descending, clipped_begin_bit, clipped_end_bit); - }; - - if (clipped_end_bit - clipped_begin_bit == sizeof(key_t_value_t) * 8) { - internal::sort_only_pairs(::std::forward(policy), keys_in, keys_out, - values_in, values_out, n, descending); - } else if (num_bytes == 1) { - transform_and_sort_pairs_f.template operator()(0); - } else if (num_bytes == 2) { - transform_and_sort_pairs_f.template operator()(0); - } else if (num_bytes <= 4) { - transform_and_sort_pairs_f.template operator()(0); - } else // if (num_bytes <= 8) - { - transform_and_sort_pairs_f.template operator()<::std::uint64_t>(0); - } -} - -// overload for Iter2 == std::nullptr_t -template -typename ::std::enable_if<::std::is_null_pointer::value>::type -sort_pairs_impl(Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, bool descending, - int begin_bit, int end_bit) { - // create temporary keys_out to discard, memory footprint could be improved by - // a specialized iterator with a single - // unchanging dummy Iter1 element - using key_t_value_t = typename std::iterator_traits::value_type; - sycl::buffer temp_keys_out{sycl::range<1>(n)}; - internal::sort_pairs_impl(std::forward(policy), keys_in, - oneapi::dpl::begin(temp_keys_out), values_in, - values_out, n, descending, begin_bit, end_bit); -} - -template -inline void segmented_sort_pairs_by_parallel_sorts( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter4 values_in, - Iter3 values_out, ::std::int64_t n, ::std::int64_t nsegments, - Iter5 begin_offsets, Iter5 end_offsets, bool descending = false, - int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - using offset_type = typename ::std::iterator_traits::value_type; - ::std::vector host_accessible_offset_starts(nsegments); - ::std::vector host_accessible_offset_ends(nsegments); - // make offsets accessible on host - ::std::copy(policy, begin_offsets, begin_offsets + nsegments, - host_accessible_offset_starts.begin()); - ::std::copy(policy, end_offsets, end_offsets + nsegments, - host_accessible_offset_ends.begin()); - - for (::std::uint64_t i = 0; i < nsegments; i++) { - ::std::uint64_t segment_begin = host_accessible_offset_starts[i]; - ::std::uint64_t segment_end = - ::std::min(n, (::std::int64_t)host_accessible_offset_ends[i]); - if (segment_begin < segment_end) { - ::dpct::sort_pairs( - policy, keys_in + segment_begin, keys_out + segment_begin, - values_in + segment_begin, values_out + segment_begin, - segment_end - segment_begin, descending, begin_bit, end_bit); - } - } -} - -template -inline void segmented_sort_keys_by_parallel_sorts( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, ::std::int64_t n, - ::std::int64_t nsegments, Iter3 begin_offsets, Iter3 end_offsets, - bool descending = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - using offset_type = typename ::std::iterator_traits::value_type; - ::std::vector host_accessible_offset_starts(nsegments); - ::std::vector host_accessible_offset_ends(nsegments); - // make offsets accessible on host - ::std::copy(policy, begin_offsets, begin_offsets + nsegments, - host_accessible_offset_starts.begin()); - ::std::copy(policy, end_offsets, end_offsets + nsegments, - host_accessible_offset_ends.begin()); - - for (::std::uint64_t i = 0; i < nsegments; i++) { - ::std::uint64_t segment_begin = host_accessible_offset_starts[i]; - ::std::uint64_t segment_end = - ::std::min(n, (::std::int64_t)host_accessible_offset_ends[i]); - if (segment_begin < segment_end) { - ::dpct::sort_keys(policy, keys_in + segment_begin, - keys_out + segment_begin, segment_end - segment_begin, - descending, begin_bit, end_bit); - } - } -} - -template -inline void segmented_sort_pairs_by_parallel_for_of_sorts( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, ::std::int64_t nsegments, - Iter5 begin_offsets, Iter5 end_offsets, bool descending = false, - int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - policy.queue().submit([&](sycl::handler &cgh) { - cgh.parallel_for(nsegments, [=](sycl::id<1> i) { - ::std::uint64_t segment_begin = begin_offsets[i]; - ::std::uint64_t segment_end = - ::std::min(n, (::std::int64_t)end_offsets[i]); - if (segment_begin == segment_end) { - return; - } - ::dpct::sort_pairs(::std::execution::seq, keys_in + segment_begin, - keys_out + segment_begin, values_in + segment_begin, - values_out + segment_begin, - segment_end - segment_begin, descending, begin_bit, - end_bit); - }); - }); - policy.queue().wait(); -} - -template -inline void segmented_sort_keys_by_parallel_for_of_sorts( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, ::std::int64_t n, - ::std::int64_t nsegments, Iter3 begin_offsets, Iter3 end_offsets, - bool descending = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - policy.queue().submit([&](sycl::handler &cgh) { - cgh.parallel_for(nsegments, [=](sycl::id<1> i) { - ::std::uint64_t segment_begin = begin_offsets[i]; - ::std::uint64_t segment_end = - ::std::min(n, (::std::int64_t)end_offsets[i]); - if (segment_begin == segment_end) { - return; - } - ::dpct::sort_keys(::std::execution::seq, keys_in + segment_begin, - keys_out + segment_begin, segment_end - segment_begin, - descending, begin_bit, end_bit); - }); - }); - policy.queue().wait(); -} - -template -inline void mark_segments(Policy &&policy, OffsetIteratorT begin_offsets, - OffsetIteratorT end_offsets, ::std::int64_t n, - ::std::int64_t nsegments, - sycl::buffer<::std::size_t, 1> segments) { - - ::std::size_t work_group_size = - policy.queue() - .get_device() - .template get_info(); - - auto sg_sizes = policy.queue() - .get_device() - .template get_info(); - ::std::size_t sub_group_size = sg_sizes.empty() ? 0 : sg_sizes.back(); - - float avg_seg_size = (float)n / (float)nsegments; - if (avg_seg_size > work_group_size) { - // If average segment size is larger than workgroup, use workgroup to - // coordinate to mark segments - policy.queue() - .submit([&](sycl::handler &h) { - auto segments_acc = segments.get_access(h); - h.parallel_for(work_group_size, ([=](sycl::id<1> id) { - for (::std::size_t seg = 0; seg < nsegments; seg++) { - ::std::size_t i = begin_offsets[seg]; - ::std::size_t end = end_offsets[seg]; - while (i + id < end) { - segments_acc[i + id] = seg; - i += work_group_size; - } - } - })); - }) - .wait(); - } else if (sub_group_size > 0 && avg_seg_size > sub_group_size / 2) { - // If average segment size is larger than half a subgroup, use subgroup to - // coordinate to mark segments - policy.queue() - .submit([&](sycl::handler &h) { - auto segments_acc = segments.get_access(h); - h.parallel_for( - sycl::nd_range<1>{work_group_size, work_group_size}, - ([=](sycl::nd_item<1> item) { - auto sub_group = item.get_sub_group(); - ::std::size_t num_subgroups = - sub_group.get_group_range().size(); - ::std::size_t local_size = sub_group.get_local_range().size(); - - ::std::size_t sub_group_id = sub_group.get_group_id(); - while (sub_group_id < nsegments) { - ::std::size_t subgroup_local_id = sub_group.get_local_id(); - ::std::size_t i = begin_offsets[sub_group_id]; - ::std::size_t end = end_offsets[sub_group_id]; - while (i + subgroup_local_id < end) { - segments_acc[i + subgroup_local_id] = sub_group_id; - i += local_size; - } - sub_group_id += num_subgroups; - } - })); - }) - .wait(); - } else { - // If average segment size is small as compared to subgroup, use single - // work item to mark each segment - policy.queue() - .submit([&](sycl::handler &h) { - auto segments_acc = segments.get_access(h); - h.parallel_for(nsegments, ([=](sycl::id<1> seg) { - for (::std::size_t i = begin_offsets[seg]; - i < end_offsets[seg]; i++) { - segments_acc[i] = seg; - } - })); - }) - .wait(); - } -} - -// The dpl_histogram namespace contains a temporary preview of an upcoming -// oneDPL histogram API. This namespace will be removed and replaced with -// corresponding calls to oneapi::dpl::histogram() -namespace dpl_histogram { - -template -constexpr inline auto __ceiling_div(const T1 &number, const T2 &divisor) { - return (number - 1) / divisor + 1; -} - -template -struct __evenly_divided_binhash_impl {}; - -template -struct __evenly_divided_binhash_impl { - T __minimum; - ::std::uint32_t __num_bins; - T __scale; - T __maximum; - __evenly_divided_binhash_impl(const T &min, const T &max, - const ::std::uint32_t &num_bins) - : __minimum(min), __maximum(max), __num_bins(num_bins), - __scale(T(num_bins) / (max - min)) {} - template std::uint32_t operator()(T2 &&value) const { - return ::std::uint32_t((::std::forward(value) - __minimum) * __scale); - } - - template bool is_valid(const T2 &value) const { - return value >= __minimum && value < __maximum; - } -}; - -// non floating point type -template -struct __evenly_divided_binhash_impl { - T __minimum; - ::std::uint32_t __num_bins; - T __range_size; - __evenly_divided_binhash_impl(const T &min, const T &max, - const ::std::uint32_t &num_bins) - : __minimum(min), __num_bins(num_bins), __range_size(max - min) {} - template ::std::uint32_t operator()(T2 &&value) const { - return ::std::uint32_t( - ((::std::uint64_t(::std::forward(value)) - __minimum) * - ::std::uint64_t(__num_bins)) / - __range_size); - } - - template bool is_valid(const T2 &value) const { - return value >= __minimum && value < __minimum + __range_size; - } -}; - -template -using __evenly_divided_binhash = - __evenly_divided_binhash_impl>; - -template struct __custom_range_binhash { - Range __boundaries; - __custom_range_binhash(Range boundaries) : __boundaries(boundaries) {} - - template ::std::uint32_t operator()(T &&value) const { - return (::std::upper_bound(__boundaries.begin(), __boundaries.end(), - ::std::forward(value)) - - __boundaries.begin()) - - 1; - } - - template bool is_valid(const T2 &value) const { - return value >= __boundaries[0] && - value < __boundaries[__boundaries.size() - 1]; - } -}; - -template -inline void __clear_wglocal_histograms(const HistAccessor &local_histogram, - const OffsetT &offset, - const Size &num_bins, - const sycl::nd_item<1> &self_item) { - ::std::uint32_t gSize = self_item.get_local_range()[0]; - ::std::uint32_t self_lidx = self_item.get_local_id(0); - ::std::uint8_t factor = __ceiling_div(num_bins, gSize); - ::std::uint8_t k; - _DPCT_PRAGMA_UNROLL - for (k = 0; k < factor - 1; k++) { - local_histogram[offset + gSize * k + self_lidx] = 0; - } - if (gSize * k + self_lidx < num_bins) { - local_histogram[offset + gSize * k + self_lidx] = 0; - } - self_item.barrier(sycl::access::fence_space::local_space); -} - -template -inline void __accum_local_register_iter(const Iter1 &in_acc, - const ::std::size_t &index, - HistReg *histogram, BinFunc func) { - const auto &x = in_acc[index]; - if (func.is_valid(x)) { - BinIdxType c = func(x); - histogram[c]++; - } -} - -template -inline void __accum_local_atomics_iter(const Iter1 &in_acc, - const ::std::size_t &index, - const HistAccessor &wg_local_histogram, - const OffsetT &offset, BinFunc func) { - using __histo_value_type = typename HistAccessor::value_type; - const auto &x = in_acc[index]; - if (func.is_valid(x)) { - BinIdxType c = func(x); - sycl::atomic_ref<__histo_value_type, sycl::memory_order::relaxed, - sycl::memory_scope::work_group, AddressSpace> - local_bin(wg_local_histogram[offset + c]); - local_bin++; - } -} - -template -inline void __reduce_out_histograms(const HistAccessorIn &in_histogram, - const OffsetT &offset, - const HistAccessorOut &out_histogram, - const Size &num_bins, - const sycl::nd_item<1> &self_item) { - ::std::uint32_t gSize = self_item.get_local_range()[0]; - ::std::uint32_t self_lidx = self_item.get_local_id(0); - ::std::uint8_t factor = __ceiling_div(num_bins, gSize); - ::std::uint8_t k; - - _DPCT_PRAGMA_UNROLL - for (k = 0; k < factor - 1; k++) { - sycl::atomic_ref - global_bin(out_histogram[gSize * k + self_lidx]); - global_bin += in_histogram[offset + gSize * k + self_lidx]; - } - if (gSize * k + self_lidx < num_bins) { - sycl::atomic_ref - global_bin(out_histogram[gSize * k + self_lidx]); - global_bin += in_histogram[offset + gSize * k + self_lidx]; - } -} - -template <::std::uint16_t ItersPerWorkItem, ::std::uint8_t BinsPerWorkItem, - typename BinType, typename Policy, typename Range1, typename Range2, - typename Size, typename IdxHashFunc, typename... Range3> -inline void __histogram_general_registers_local_reduction( - Policy &&policy, ::std::uint16_t work_group_size, Range1 &&input, - Range2 &&bins, const Size &num_bins, IdxHashFunc func, - Range3 &&...opt_range) { - const ::std::size_t N = input.size(); - using __local_histogram_type = ::std::uint32_t; - using __private_histogram_type = ::std::uint16_t; - - ::std::size_t segments = __ceiling_div(N, work_group_size * ItersPerWorkItem); - auto e = policy.queue().submit([&](auto &h) { - // Temporary use of stable non-public API from oneDPL, this function will - // be replaced with oneDPL call in an upcoming release. - oneapi::dpl::__ranges::__require_access(h, input, bins, opt_range...); - sycl::local_accessor<__local_histogram_type, 1> local_histogram( - sycl::range(num_bins), h); - h.parallel_for( - sycl::nd_range<1>(segments * work_group_size, work_group_size), - [=](sycl::nd_item<1> __self_item) { - using __bin_idx_type = ::std::uint8_t; - const ::std::size_t __self_lidx = __self_item.get_local_id(0); - const ::std::size_t __wgroup_idx = __self_item.get_group(0); - const ::std::size_t __seg_start = - work_group_size * ItersPerWorkItem * __wgroup_idx; - - __clear_wglocal_histograms(local_histogram, 0, num_bins, __self_item); - __private_histogram_type histogram[BinsPerWorkItem]; - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t k = 0; k < BinsPerWorkItem; k++) { - histogram[k] = 0; - } - - if (__seg_start + work_group_size * ItersPerWorkItem < N) { - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t idx = 0; idx < ItersPerWorkItem; idx++) { - __accum_local_register_iter<__bin_idx_type>( - input, __seg_start + idx * work_group_size + __self_lidx, - histogram, func); - } - } else { - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t idx = 0; idx < ItersPerWorkItem; idx++) { - ::std::size_t __val_idx = - __seg_start + idx * work_group_size + __self_lidx; - if (__val_idx < N) { - __accum_local_register_iter<__bin_idx_type>(input, __val_idx, - histogram, func); - } - } - } - - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t k = 0; k < num_bins; k++) { - sycl::atomic_ref<__local_histogram_type, - sycl::memory_order::relaxed, - sycl::memory_scope::work_group, - sycl::access::address_space::local_space> - local_bin(local_histogram[k]); - local_bin += histogram[k]; - } - - __self_item.barrier(sycl::access::fence_space::local_space); - - __reduce_out_histograms(local_histogram, 0, bins, num_bins, - __self_item); - }); - }); - e.wait(); -} - -template <::std::uint16_t ItersPerWorkItem, typename BinType, typename Policy, - typename Range1, typename Range2, typename Size, typename IdxHashFunc, - typename... Range3> -inline void __histogram_general_local_atomics(Policy &&policy, - ::std::uint16_t work_group_size, - Range1 &&input, Range2 &&bins, - const Size &num_bins, - IdxHashFunc func, - Range3 &&...opt_range) { - const ::std::size_t N = input.size(); - ::std::size_t segments = __ceiling_div(N, work_group_size * ItersPerWorkItem); - auto e = policy.queue().submit([&](auto &h) { - // Temporary use of stable non-public API from oneDPL, this function will - // be replaced with oneDPL call in an upcoming release. - oneapi::dpl::__ranges::__require_access(h, input, bins, opt_range...); - sycl::local_accessor<::std::uint32_t, 1> local_histogram( - sycl::range(num_bins), h); - h.parallel_for( - sycl::nd_range<1>(segments * work_group_size, work_group_size), - [=](sycl::nd_item<1> __self_item) { - using __bin_idx_type = ::std::uint16_t; - constexpr auto __atomic_address_space = - sycl::access::address_space::local_space; - const ::std::size_t __self_lidx = __self_item.get_local_id(0); - const ::std::uint32_t __wgroup_idx = __self_item.get_group(0); - const ::std::size_t __seg_start = - work_group_size * __wgroup_idx * ItersPerWorkItem; - - __clear_wglocal_histograms(local_histogram, 0, num_bins, __self_item); - - if (__seg_start + work_group_size * ItersPerWorkItem < N) { - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t idx = 0; idx < ItersPerWorkItem; idx++) { - __accum_local_atomics_iter<__bin_idx_type, - __atomic_address_space>( - input, __seg_start + idx * work_group_size + __self_lidx, - local_histogram, 0, func); - } - } else { - _DPCT_PRAGMA_UNROLL - for (::std::uint8_t idx = 0; idx < ItersPerWorkItem; idx++) { - ::std::size_t __val_idx = - __seg_start + idx * work_group_size + __self_lidx; - if (__val_idx < N) { - __accum_local_atomics_iter<__bin_idx_type, - __atomic_address_space>( - input, __val_idx, local_histogram, 0, func); - } - } - } - __self_item.barrier(sycl::access::fence_space::local_space); - - __reduce_out_histograms(local_histogram, 0, bins, num_bins, - __self_item); - }); - }); - - e.wait(); -} - -template <::std::uint16_t __min_iters_per_work_item, typename BinType, - typename Policy, typename Range1, typename Range2, typename Size, - typename IdxHashFunc, typename... Range3> -inline void __histogram_general_private_global_atomics( - Policy &&policy, ::std::uint16_t work_group_size, Range1 &&input, - Range2 &&bins, const Size &num_bins, IdxHashFunc func, - Range3 &&...opt_range) { - - const ::std::size_t N = input.size(); - auto __global_mem_size = - policy.queue() - .get_device() - .template get_info(); - const ::std::size_t max_segments = - ::std::min(__global_mem_size / (num_bins * sizeof(BinType)), - __ceiling_div(N, work_group_size * __min_iters_per_work_item)); - const ::std::size_t iters_per_work_item = - __ceiling_div(N, max_segments * work_group_size); - ::std::size_t segments = - __ceiling_div(N, work_group_size * iters_per_work_item); - - sycl::buffer private_histograms( - sycl::range<1>(segments * num_bins)); - - auto e = policy.queue().submit([&](auto &h) { - // Temporary use of stable non-public API from oneDPL, this function will - // be replaced with oneDPL call in an upcoming release. - oneapi::dpl::__ranges::__require_access(h, input, bins, opt_range...); - sycl::accessor hacc_private(private_histograms, h, sycl::read_write, - sycl::no_init); - h.parallel_for( - sycl::nd_range<1>(segments * work_group_size, work_group_size), - [=](sycl::nd_item<1> __self_item) { - using __bin_idx_type = ::std::uint32_t; - constexpr auto __atomic_address_space = - sycl::access::address_space::global_space; - const ::std::size_t __self_lidx = __self_item.get_local_id(0); - const ::std::size_t __wgroup_idx = __self_item.get_group(0); - const ::std::size_t __seg_start = - work_group_size * iters_per_work_item * __wgroup_idx; - - __clear_wglocal_histograms(hacc_private, __wgroup_idx * num_bins, - num_bins, __self_item); - if (__seg_start + work_group_size * iters_per_work_item < N) { - for (::std::size_t idx = 0; idx < iters_per_work_item; idx++) { - __accum_local_atomics_iter<__bin_idx_type, - __atomic_address_space>( - input, __seg_start + idx * work_group_size + __self_lidx, - hacc_private, __wgroup_idx * num_bins, func); - } - } else { - for (::std::size_t idx = 0; idx < iters_per_work_item; idx++) { - ::std::size_t __val_idx = - __seg_start + idx * work_group_size + __self_lidx; - if (__val_idx < N) { - __accum_local_atomics_iter<__bin_idx_type, - __atomic_address_space>( - input, __val_idx, hacc_private, __wgroup_idx * num_bins, - func); - } - } - } - __self_item.barrier(sycl::access::fence_space::local_space); - - __reduce_out_histograms(hacc_private, - __wgroup_idx * num_bins, bins, - num_bins, __self_item); - }); - }); - e.wait(); -} - -template -inline Iter2 -__histogram_general_select_best(Policy &&policy, Iter1 first, Iter1 last, - Iter2 histogram_first, const Size &num_bins, - IdxHashFunc func, Range &&...opt_range) { - using __histo_value_type = typename ::std::iterator_traits::value_type; - auto __local_mem_size = - policy.queue() - .get_device() - .template get_info(); - constexpr ::std::uint8_t __max_registers = 16; - - // Temporary use of stable non-public API from oneDPL, this function will be - // replaced with oneDPL call in an upcoming release. - auto keep_bins = oneapi::dpl::__ranges::__get_sycl_range< - oneapi::dpl::__par_backend_hetero::access_mode::write, Iter2>(); - auto bins_buf = keep_bins(histogram_first, histogram_first + num_bins); - - oneapi::dpl::fill(policy, bins_buf.all_view().begin(), - bins_buf.all_view().end(), __histo_value_type(0)); - auto N = last - first; - if (N > 0) { - // Temporary use of stable non-public API from oneDPL, this function will - // be replaced with oneDPL call in an upcoming release. - auto keep_input = oneapi::dpl::__ranges::__get_sycl_range< - oneapi::dpl::__par_backend_hetero::access_mode::read, Iter1>(); - auto input_buf = keep_input(first, last); - - ::std::size_t max_work_group_size = - policy.queue() - .get_device() - .template get_info(); - ::std::size_t work_group_size = - ::std::min(max_work_group_size, ::std::size_t(1024)); - - if (num_bins < __max_registers) { - - // If bins fit into registers, use register private accumulation - __histogram_general_registers_local_reduction<32, 16, __histo_value_type>( - ::std::forward(policy), work_group_size, input_buf.all_view(), - bins_buf.all_view(), num_bins, func, - ::std::forward(opt_range)...); - } else if (num_bins * sizeof(__histo_value_type) < __local_mem_size) { - // If bins fit into SLM, use local atomics - - // Experimentally determined iters per work-item - if (N <= 524288) { - __histogram_general_local_atomics<4, __histo_value_type>( - ::std::forward(policy), work_group_size, - input_buf.all_view(), bins_buf.all_view(), num_bins, func, - ::std::forward(opt_range)...); - } else { - __histogram_general_local_atomics<32, __histo_value_type>( - ::std::forward(policy), work_group_size, - input_buf.all_view(), bins_buf.all_view(), num_bins, func, - ::std::forward(opt_range)...); - } - } else // Otherwise, use global atomics (private copies per workgroup) - { - // Experimentally determined iters per work-item - if (N <= 524288) { - __histogram_general_private_global_atomics<4, __histo_value_type>( - ::std::forward(policy), work_group_size, - input_buf.all_view(), bins_buf.all_view(), num_bins, func, - ::std::forward(opt_range)...); - } else { - __histogram_general_private_global_atomics<32, __histo_value_type>( - ::std::forward(policy), work_group_size, - input_buf.all_view(), bins_buf.all_view(), num_bins, func, - ::std::forward(opt_range)...); - } - } - } - return histogram_first + num_bins; -} - -template -inline ::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value, - Iter2> -histogram(Policy &&policy, Iter1 first, Iter1 last, Iter2 histogram_first, - const Size &num_bins, const T &first_bin_min_val, - const T &last_bin_max_val) { - return __histogram_general_select_best( - ::std::forward(policy), first, last, histogram_first, num_bins, - __evenly_divided_binhash(first_bin_min_val, last_bin_max_val, - num_bins)); -} - -template -inline ::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value, - Iter2> -histogram(Policy &&policy, Iter1 first, Iter1 last, Iter2 histogram_first, - Iter3 boundary_first, Iter3 boundary_last) { - // Temporary use of stable non-public API from oneDPL, this function will be - // replaced with oneDPL call in an upcoming release. - auto keep_boundaries = oneapi::dpl::__ranges::__get_sycl_range< - oneapi::dpl::__par_backend_hetero::access_mode::read, Iter3>(); - auto boundary_buf = keep_boundaries(boundary_first, boundary_last); - - return __histogram_general_select_best( - ::std::forward(policy), first, last, histogram_first, - (boundary_last - boundary_first) - 1, - __custom_range_binhash{boundary_buf.all_view()}, boundary_buf.all_view()); -} -} // end namespace dpl_histogram - -} // end namespace internal - -// Evenly Divided Histogram of a 1-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -histogram_even(Policy &&policy, Iter1 d_samples, Iter2 d_histogram, - int num_levels, T lower_level, T upper_level, Size num_samples) { - internal::dpl_histogram::histogram(::std::forward(policy), d_samples, - d_samples + num_samples, d_histogram, - num_levels - 1, lower_level, upper_level); -} - -// Evenly Divided Histogram of a 2-D ROI in a flattened 2-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -histogram_even_roi(Policy &&policy, Iter1 d_samples, Iter2 d_histogram, - int num_levels, T lower_level, T upper_level, - OffsetT num_row_samples, OffsetT num_rows, - ::std::size_t row_stride_bytes) { - return histogram_even( - ::std::forward(policy), - oneapi::dpl::permutation_iterator( - d_samples, - internal::__roi_2d_index_functor( - num_row_samples, - row_stride_bytes / - sizeof(typename ::std::iterator_traits::value_type))), - d_histogram, num_levels, lower_level, upper_level, - num_row_samples * num_rows); -} - -// Evenly Divided Multi-Channel Histogram of a 1-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -multi_histogram_even(Policy &&policy, Iter1 d_samples, - Iter2 d_histogram[NumActiveChannels], - int num_levels[NumActiveChannels], - T lower_level[NumActiveChannels], - T upper_level[NumActiveChannels], Size num_pixels) { - for (int active_channel = 0; active_channel < NumActiveChannels; - active_channel++) { - histogram_even( - policy, - oneapi::dpl::permutation_iterator( - d_samples, - internal::__interleaved_index_functor(NumChannels, active_channel)), - d_histogram[active_channel], num_levels[active_channel], - lower_level[active_channel], upper_level[active_channel], num_pixels); - } -} - -// Evenly Divided Multi-Channel Histogram of a 2-D ROI in a flattened 2-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -multi_histogram_even_roi(Policy &&policy, Iter1 d_samples, - Iter2 d_histogram[NumActiveChannels], - int num_levels[NumActiveChannels], - T lower_level[NumActiveChannels], - T upper_level[NumActiveChannels], - OffsetT num_row_samples, OffsetT num_rows, - ::std::size_t row_stride_bytes) { - for (int active_channel = 0; active_channel < NumActiveChannels; - active_channel++) { - histogram_even( - policy, - oneapi::dpl::permutation_iterator( - d_samples, - internal::__composition_functor( - internal::__roi_2d_index_functor( - num_row_samples, - row_stride_bytes / - (NumChannels * sizeof(typename ::std::iterator_traits< - Iter1>::value_type))), - internal::__interleaved_index_functor(NumChannels, - active_channel))), - d_histogram[active_channel], num_levels[active_channel], - lower_level[active_channel], upper_level[active_channel], - num_row_samples * num_rows); - } -} - -// Custom Range Histogram of a 1-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -histogram_range(Policy &&policy, Iter1 d_samples, Iter2 d_histogram, - int num_levels, Iter3 d_levels, Size num_samples) { - internal::dpl_histogram::histogram(::std::forward(policy), d_samples, - d_samples + num_samples, d_histogram, - d_levels, d_levels + num_levels); -} - -// Custom Range Histogram of a 2-D ROI in a flattened 2-D Array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -histogram_range_roi(Policy &&policy, Iter1 d_samples, Iter2 d_histogram, - int num_levels, Iter3 d_levels, OffsetT num_row_samples, - OffsetT num_rows, ::std::size_t row_stride_bytes) { - return histogram_range( - ::std::forward(policy), - oneapi::dpl::permutation_iterator( - d_samples, - internal::__roi_2d_index_functor( - num_row_samples, - row_stride_bytes / - sizeof(typename ::std::iterator_traits::value_type))), - d_histogram, num_levels, d_levels, num_row_samples * num_rows); -} - -// Custom Range Multi-Channel Histogram of a 1-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -multi_histogram_range(Policy &&policy, Iter1 d_samples, - Iter2 d_histogram[NumActiveChannels], - int num_levels[NumActiveChannels], - Iter3 d_levels[NumActiveChannels], Size num_pixels) { - for (int active_channel = 0; active_channel < NumActiveChannels; - active_channel++) { - histogram_range(policy, - oneapi::dpl::permutation_iterator( - d_samples, internal::__interleaved_index_functor( - NumChannels, active_channel)), - d_histogram[active_channel], num_levels[active_channel], - d_levels[active_channel], num_pixels); - } -} - -// Custom Range Multi-Channel Histogram of a 2-D ROI in a flattened 2-D array -template -::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -multi_histogram_range_roi(Policy &&policy, Iter1 d_samples, - Iter2 d_histogram[NumActiveChannels], - int num_levels[NumActiveChannels], - Iter3 d_levels[NumActiveChannels], - OffsetT num_row_samples, OffsetT num_rows, - ::std::size_t row_stride_bytes) { - for (int active_channel = 0; active_channel < NumActiveChannels; - active_channel++) { - histogram_range( - policy, - oneapi::dpl::permutation_iterator( - d_samples, - internal::__composition_functor( - internal::__roi_2d_index_functor( - num_row_samples, - row_stride_bytes / - (NumChannels * sizeof(typename ::std::iterator_traits< - Iter1>::value_type))), - internal::__interleaved_index_functor(NumChannels, - active_channel))), - d_histogram[active_channel], num_levels[active_channel], - d_levels[active_channel], num_row_samples * num_rows); - } -} - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value> -sort_pairs(Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, bool descending, int begin_bit, - int end_bit) { - internal::sort_pairs_impl(std::forward(policy), keys_in, keys_out, - values_in, values_out, n, descending, begin_bit, - end_bit); -} - -template -inline void sort_pairs( - Policy &&policy, io_iterator_pair &keys, - io_iterator_pair &values, ::std::int64_t n, bool descending = false, - bool do_swap_iters = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - sort_pairs(::std::forward(policy), keys.first(), keys.second(), - values.first(), values.second(), n, descending, begin_bit, - end_bit); - if (do_swap_iters) { - keys.swap(); - values.swap(); - } -} - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value> -sort_keys(Policy &&policy, Iter1 keys_in, Iter2 keys_out, ::std::int64_t n, - bool descending, int begin_bit, int end_bit) { - using key_t_value_t = typename ::std::iterator_traits::value_type; - - int clipped_begin_bit = ::std::max(begin_bit, 0); - int clipped_end_bit = - ::std::min((::std::uint64_t)end_bit, sizeof(key_t_value_t) * 8); - int num_bytes = (clipped_end_bit - clipped_begin_bit - 1) / 8 + 1; - - auto transform_and_sort_f = [&](auto x) { - using T = typename ::std::decay_t; - internal::transform_and_sort( - ::std::forward(policy), keys_in, keys_out, n, descending, - clipped_begin_bit, clipped_end_bit); - }; - - if (clipped_end_bit - clipped_begin_bit == sizeof(key_t_value_t) * 8) { - internal::sort_only(::std::forward(policy), keys_in, keys_out, n, - descending); - } else if (num_bytes == 1) { - transform_and_sort_f.template operator()(0); - } else if (num_bytes == 2) { - transform_and_sort_f.template operator()(0); - } else if (num_bytes <= 4) { - transform_and_sort_f.template operator()(0); - } else // if (num_bytes <= 8) - { - transform_and_sort_f.template operator()<::std::uint64_t>(0); - } -} - -template -inline void sort_keys( - Policy &&policy, io_iterator_pair &keys, ::std::int64_t n, - bool descending = false, bool do_swap_iters = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - sort_keys(std::forward(policy), keys.first(), keys.second(), n, - descending, begin_bit, end_bit); - if (do_swap_iters) - keys.swap(); -} - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value> -segmented_sort_keys( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, ::std::int64_t n, - ::std::int64_t nsegments, Iter3 begin_offsets, Iter3 end_offsets, - bool descending = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - int compute_units = - policy.queue() - .get_device() - .template get_info(); - auto sg_sizes = policy.queue() - .get_device() - .template get_info(); - int subgroup_size = sg_sizes.empty() ? 1 : sg_sizes.back(); - // parallel for of serial sorts when we have sufficient number of segments for - // load balance when number of segments is large as compared to our target - // compute capability - if (nsegments > - compute_units * - (policy.queue().get_device().is_gpu() ? subgroup_size : 1)) { - dpct::internal::segmented_sort_keys_by_parallel_for_of_sorts( - ::std::forward(policy), keys_in, keys_out, n, nsegments, - begin_offsets, end_offsets, descending, begin_bit, end_bit); - } else - { - dpct::internal::segmented_sort_keys_by_parallel_sorts( - ::std::forward(policy), keys_in, keys_out, n, nsegments, - begin_offsets, end_offsets, descending, begin_bit, end_bit); - } -} - -template -inline void segmented_sort_keys( - Policy &&policy, io_iterator_pair &keys, ::std::int64_t n, - ::std::int64_t nsegments, Iter2 begin_offsets, Iter2 end_offsets, - bool descending = false, bool do_swap_iters = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - segmented_sort_keys(::std::forward(policy), keys.first(), - keys.second(), n, nsegments, begin_offsets, end_offsets, - descending, begin_bit, end_bit); - if (do_swap_iters) { - keys.swap(); - } -} - -template -inline ::std::enable_if_t::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value> -segmented_sort_pairs( - Policy &&policy, Iter1 keys_in, Iter2 keys_out, Iter3 values_in, - Iter4 values_out, ::std::int64_t n, ::std::int64_t nsegments, - Iter5 begin_offsets, Iter5 end_offsets, bool descending = false, - int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - int compute_units = - policy.queue() - .get_device() - .template get_info(); - auto sg_sizes = policy.queue() - .get_device() - .template get_info(); - int subgroup_size = sg_sizes.empty() ? 1 : sg_sizes.back(); - // parallel for of serial sorts when we have sufficient number of segments for - // load balance when number of segments is large as compared to our target - // compute capability - if (nsegments > - compute_units * - (policy.queue().get_device().is_gpu() ? subgroup_size : 1)) { - dpct::internal::segmented_sort_pairs_by_parallel_for_of_sorts( - ::std::forward(policy), keys_in, keys_out, values_in, - values_out, n, nsegments, begin_offsets, end_offsets, descending, - begin_bit, end_bit); - } else - { - dpct::internal::segmented_sort_pairs_by_parallel_sorts( - ::std::forward(policy), keys_in, keys_out, values_in, - values_out, n, nsegments, begin_offsets, end_offsets, descending, - begin_bit, end_bit); - } -} - -template -inline void segmented_sort_pairs( - Policy &&policy, io_iterator_pair &keys, - io_iterator_pair &values, ::std::int64_t n, ::std::int64_t nsegments, - Iter3 begin_offsets, Iter3 end_offsets, bool descending = false, - bool do_swap_iters = false, int begin_bit = 0, - int end_bit = sizeof(typename ::std::iterator_traits::value_type) * - 8) { - segmented_sort_pairs(std::forward(policy), keys.first(), - keys.second(), values.first(), values.second(), n, - nsegments, begin_offsets, end_offsets, descending, - begin_bit, end_bit); - if (do_swap_iters) { - keys.swap(); - values.swap(); - } -} - -template -inline void reduce_argmax(Policy &&policy, Iter1 input, Iter2 output, - ::std::size_t n) { - dpct::arg_index_input_iterator input_arg_idx(input); - auto ret = ::std::max_element( - policy, input_arg_idx, input_arg_idx + n, - [](const auto &a, const auto &b) { return (a.value < b.value); }); - ::std::copy(::std::forward(policy), ret, ret + 1, output); -} - -template -inline void reduce_argmin(Policy &&policy, Iter1 input, Iter2 output, - ::std::size_t n) { - dpct::arg_index_input_iterator input_arg_idx(input); - auto ret = ::std::min_element( - policy, input_arg_idx, input_arg_idx + n, - [](const auto &a, const auto &b) { return (a.value < b.value); }); - ::std::copy(::std::forward(policy), ret, ret + 1, output); -} - -template -inline ::std::pair equal_range(Policy &&policy, Iter1 start, - Iter1 end, const ValueT &value, - CompT comp) { - ::std::vector<::std::int64_t> res_lower(1); - ::std::vector<::std::int64_t> res_upper(1); - ::std::vector value_vec(1, value); - ::oneapi::dpl::lower_bound(policy, start, end, value_vec.begin(), - value_vec.end(), res_lower.begin(), comp); - ::oneapi::dpl::upper_bound(::std::forward(policy), start, end, - value_vec.begin(), value_vec.end(), - res_upper.begin(), comp); - return ::std::make_pair(start + res_lower[0], start + res_upper[0]); -} - -template -inline ::std::pair equal_range(Policy &&policy, Iter1 start, - Iter1 end, const ValueT &value) { - return equal_range(::std::forward(policy), start, end, value, - internal::__less()); -} - -template -inline ::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -segmented_reduce_argmin(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - ::std::int64_t nsegments, Iter3 begin_offsets, - Iter3 end_offsets) { - policy.queue().submit([&](sycl::handler &cgh) { - cgh.parallel_for(nsegments, [=](sycl::id<1> i) { - if (end_offsets[i] <= begin_offsets[i]) { - keys_out[i] = dpct::key_value_pair( - 1, ::std::numeric_limits< - typename ::std::iterator_traits::value_type>::max()); - } else { - dpct::arg_index_input_iterator arg_index(keys_in + - begin_offsets[i]); - keys_out[i] = *::std::min_element( - arg_index, arg_index + (end_offsets[i] - begin_offsets[i]), - [](const auto &a, const auto &b) { return a.value < b.value; }); - } - }); - }); - policy.queue().wait(); -} - -template -inline ::std::enable_if_t< - dpct::internal::is_iterator::value && - dpct::internal::is_iterator::value && - internal::is_hetero_execution_policy<::std::decay_t>::value> -segmented_reduce_argmax(Policy &&policy, Iter1 keys_in, Iter2 keys_out, - ::std::int64_t nsegments, Iter3 begin_offsets, - Iter3 end_offsets) { - policy.queue().submit([&](sycl::handler &cgh) { - cgh.parallel_for(nsegments, [=](sycl::id<1> i) { - if (end_offsets[i] <= begin_offsets[i]) { - keys_out[i] = dpct::key_value_pair( - 1, - ::std::numeric_limits< - typename ::std::iterator_traits::value_type>::lowest()); - } else { - dpct::arg_index_input_iterator arg_index(keys_in + - begin_offsets[i]); - keys_out[i] = *::std::max_element( - arg_index, arg_index + (end_offsets[i] - begin_offsets[i]), - [](const auto &a, const auto &b) { return a.value < b.value; }); - } - }); - }); - policy.queue().wait(); -} - -template -void nontrivial_run_length_encode(ExecutionPolicy &&policy, - InputIterator input_beg, - OutputIterator1 offsets_out, - OutputIterator2 lengths_out, - OutputIterator3 num_runs, - ::std::int64_t num_items) { - using oneapi::dpl::make_transform_iterator; - using oneapi::dpl::make_zip_iterator; - using offsets_t = - typename ::std::iterator_traits::value_type; - using lengths_t = - typename ::std::iterator_traits::value_type; - - auto input_end = input_beg + num_items; - // First element must be nontrivial run (start of first segment) - auto first_adj_it = oneapi::dpl::adjacent_find(policy, input_beg, input_end); - auto first_adj_idx = ::std::distance(input_beg, first_adj_it); - if (first_adj_it == input_end) { - ::std::fill(policy, num_runs, num_runs + 1, 0); - return; - } - auto get_prev_idx_element = [first_adj_idx](const auto &idx) { - auto out_idx = idx + first_adj_idx; - return (out_idx == 0) ? 0 : out_idx - 1; - }; - auto get_next_idx_element = [first_adj_idx, num_items](const auto &idx) { - auto out_idx = idx + first_adj_idx; - return (out_idx == num_items - 1) ? num_items - 1 : out_idx + 1; - }; - // TODO: Use shifted view to pad range once oneDPL ranges is non-experimental - auto left_shifted_input_beg = - oneapi::dpl::make_permutation_iterator(input_beg, get_prev_idx_element); - auto right_shifted_input_beg = - oneapi::dpl::make_permutation_iterator(input_beg, get_next_idx_element); - // Segment type for ith idx consists of zip of iterators at (i-1, i, i+1) - // padded at the ends - auto zipped_keys_beg = make_zip_iterator( - left_shifted_input_beg, input_beg, right_shifted_input_beg, - oneapi::dpl::counting_iterator(0)); - // Set flag at the beginning of new nontrivial run (ex: (2, 3, 3) -> 1) - auto key_flags_beg = - make_transform_iterator(zipped_keys_beg, [num_items](const auto &zipped) { - using ::std::get; - bool last_idx_mask = get<3>(zipped) != num_items - 1; - return (get<0>(zipped) != get<1>(zipped) && - get<1>(zipped) == get<2>(zipped)) && - last_idx_mask; - }); - auto count_beg = oneapi::dpl::counting_iterator(0); - auto const_it = dpct::make_constant_iterator(lengths_t(1)); - // Check for presence of nontrivial element at current index - auto tr_nontrivial_flags = make_transform_iterator( - make_zip_iterator(left_shifted_input_beg, input_beg), - [](const auto &zip) { - using ::std::get; - return get<0>(zip) == get<1>(zip); - }); - auto zipped_vals_beg = - make_zip_iterator(tr_nontrivial_flags, count_beg, const_it); - auto pred = [](bool lhs, bool rhs) { return !rhs; }; - auto op = [](auto lhs, const auto &rhs) { - using ::std::get; - - // Update length count of run. - // The first call to this op will use the first element of the input as lhs - // and second element as rhs. get<0>(first_element) is ignored in favor of a - // constant `1` in get<2>, avoiding the need for special casing the first - // element. The constant `1` utilizes the knowledge that each segment begins - // with a nontrivial run. - get<2>(lhs) += get<0>(rhs); - - // A run's starting index is stored in get<1>(lhs) as the initial value in - // the segment and is preserved throughout the segment's reduction as the - // nontrivial run's offset. - - return ::std::move(lhs); - }; - auto zipped_out_beg = make_zip_iterator(oneapi::dpl::discard_iterator(), - offsets_out, lengths_out); - auto [_, zipped_out_vals_end] = oneapi::dpl::reduce_by_segment( - policy, key_flags_beg + first_adj_idx, key_flags_beg + num_items, - zipped_vals_beg + first_adj_idx, oneapi::dpl::discard_iterator(), - zipped_out_beg, pred, op); - auto ret_dist = ::std::distance(zipped_out_beg, zipped_out_vals_end); - ::std::fill(policy, num_runs, num_runs + 1, ret_dist); -} - -} // end namespace dpct - -#endif diff --git a/dpct/dpl_extras/dpcpp_extensions.h b/dpct/dpl_extras/dpcpp_extensions.h deleted file mode 100644 index 05a0068e6..000000000 --- a/dpct/dpl_extras/dpcpp_extensions.h +++ /dev/null @@ -1,747 +0,0 @@ -//==---- dpcpp_extensions.h ------------------*- C++ -*---------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------===// - -#ifndef __DPCT_DPCPP_EXTENSIONS_H__ -#define __DPCT_DPCPP_EXTENSIONS_H__ - -#include -#include - -#ifdef SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS -#include -#endif - -#include "../dpct.hpp" -#include "functional.h" - -namespace dpct { -namespace group { -namespace detail { - -template -constexpr auto __reduce_over_group(_Args... __args) { - return sycl::reduce_over_group(__args...); -} - -template constexpr auto __group_broadcast(_Args... __args) { - return sycl::group_broadcast(__args...); -} - -template -constexpr auto __exclusive_scan_over_group(_Args... __args) { - return sycl::exclusive_scan_over_group(__args...); -} - -template -constexpr auto __inclusive_scan_over_group(_Args... __args) { - return sycl::inclusive_scan_over_group(__args...); -} - -} // end namespace detail - -/// Perform an exclusive scan over the values of inputs from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param inputs Pointer to the input data for the scan operation. -/// \param outputs Pointer to the location where scan results will be stored. -/// \param init initial value of the scan result. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. -template -__dpct_inline__ void -exclusive_scan(const Item &item, T (&inputs)[VALUES_PER_THREAD], - T (&outputs)[VALUES_PER_THREAD], T init, - BinaryOperation binary_op) { - T result = inputs[0]; - -#pragma unroll - for (int i = 1; i < VALUES_PER_THREAD; ++i) { - result = binary_op(result, inputs[i]); - } - - T exclusive_result = - detail::__exclusive_scan_over_group(item.get_group(), result, binary_op); - - T input = inputs[0]; - if (item.get_local_linear_id() == 0) { - outputs[0] = init; - } else { - outputs[0] = exclusive_result; - } - -#pragma unroll - for (int i = 1; i < VALUES_PER_THREAD; ++i) { - T output = binary_op(input, outputs[i - 1]); - input = inputs[i]; - outputs[i] = output; - } -} - -/// Perform an exclusive scan over the values of input from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param input Input data for the scan operation. -/// \param init initial value of the scan result. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \param group_aggregate group-wide aggregate of all inputs -/// in the work-items of the group. \returns exclusive scan of the first i -/// work-items where item is the i-th work item. -template -__dpct_inline__ T -exclusive_scan(const Item &item, T input, T init, BinaryOperation binary_op, - T &group_aggregate) { - T output = detail::__exclusive_scan_over_group(item.get_group(), input, init, - binary_op); - if (item.get_local_linear_id() == item.get_local_range().size() - 1) { - group_aggregate = binary_op(output, input); - } - - group_aggregate = detail::__group_broadcast( - item.get_group(), group_aggregate, item.get_local_range().size() - 1); - return output; -} - -/// Perform an exclusive scan over the values of input from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param input Input data for the scan operation. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \param prefix_callback_op functor invoked by the first -/// work-item in the group that returns the -/// initial value in the resulting scan of the work-items in the group. -/// \returns exclusive scan of the input elements assigned to work-items in the -/// group. -template -__dpct_inline__ T -exclusive_scan(const Item &item, T input, BinaryOperation binary_op, - GroupPrefixCallbackOperation &prefix_callback_op) { - T group_aggregate; - - T output = - detail::__exclusive_scan_over_group(item.get_group(), input, binary_op); - if (item.get_local_linear_id() == item.get_local_range().size() - 1) { - group_aggregate = binary_op(output, input); - } - - group_aggregate = detail::__group_broadcast( - item.get_group(), group_aggregate, item.get_local_range().size() - 1); - - T group_prefix = prefix_callback_op(group_aggregate); - if (item.get_local_linear_id() == 0) { - output = group_prefix; - } else { - output = binary_op(group_prefix, output); - } - - return output; -} - -namespace detail { - -typedef uint16_t digit_counter_type; -typedef uint32_t packed_counter_type; - -template struct log2 { - enum { VALUE = log2> 1), COUNT + 1>::VALUE }; -}; - -template struct log2 { - enum { VALUE = (1 << (COUNT - 1) < N) ? COUNT : COUNT - 1 }; -}; - -template class radix_rank { -public: - static size_t get_local_memory_size(size_t group_threads) { - return group_threads * PADDED_COUNTER_LANES * sizeof(packed_counter_type); - } - - radix_rank(uint8_t *local_memory) : _local_memory(local_memory) {} - - template - __dpct_inline__ void - rank_keys(const Item &item, uint32_t (&keys)[VALUES_PER_THREAD], - int (&ranks)[VALUES_PER_THREAD], int current_bit, int num_bits) { - - digit_counter_type thread_prefixes[VALUES_PER_THREAD]; - digit_counter_type *digit_counters[VALUES_PER_THREAD]; - digit_counter_type *buffer = - reinterpret_cast(_local_memory); - - reset_local_memory(item); - - item.barrier(sycl::access::fence_space::local_space); - -#pragma unroll - for (int i = 0; i < VALUES_PER_THREAD; ++i) { - uint32_t digit = ::dpct::bfe(keys[i], current_bit, num_bits); - uint32_t sub_counter = digit >> LOG_COUNTER_LANES; - uint32_t counter_lane = digit & (COUNTER_LANES - 1); - - if (DESCENDING) { - sub_counter = PACKING_RATIO - 1 - sub_counter; - counter_lane = COUNTER_LANES - 1 - counter_lane; - } - - digit_counters[i] = - &buffer[counter_lane * item.get_local_range().size() * PACKING_RATIO + - item.get_local_linear_id() * PACKING_RATIO + sub_counter]; - thread_prefixes[i] = *digit_counters[i]; - *digit_counters[i] = thread_prefixes[i] + 1; - } - - item.barrier(sycl::access::fence_space::local_space); - - scan_counters(item); - - item.barrier(sycl::access::fence_space::local_space); - - for (int i = 0; i < VALUES_PER_THREAD; ++i) { - ranks[i] = thread_prefixes[i] + *digit_counters[i]; - } - } - -private: - template - __dpct_inline__ void reset_local_memory(const Item &item) { - packed_counter_type *ptr = - reinterpret_cast(_local_memory); - -#pragma unroll - for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { - ptr[i * item.get_local_range().size() + item.get_local_linear_id()] = 0; - } - } - - template - __dpct_inline__ packed_counter_type upsweep(const Item &item) { - packed_counter_type sum = 0; - packed_counter_type *ptr = - reinterpret_cast(_local_memory); - -#pragma unroll - for (int i = 0; i < PADDED_COUNTER_LANES; i++) { - cached_segment[i] = - ptr[item.get_local_linear_id() * PADDED_COUNTER_LANES + i]; - } - -#pragma unroll - for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { - sum += cached_segment[i]; - } - - return sum; - } - - template - __dpct_inline__ void - exclusive_downsweep(const Item &item, packed_counter_type raking_partial) { - packed_counter_type *ptr = - reinterpret_cast(_local_memory); - packed_counter_type sum = raking_partial; - -#pragma unroll - for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { - packed_counter_type value = cached_segment[i]; - cached_segment[i] = sum; - sum += value; - } - -#pragma unroll - for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { - ptr[item.get_local_linear_id() * PADDED_COUNTER_LANES + i] = - cached_segment[i]; - } - } - - struct prefix_callback { - __dpct_inline__ packed_counter_type - operator()(packed_counter_type block_aggregate) { - packed_counter_type block_prefix = 0; - -#pragma unroll - for (int packed = 1; packed < PACKING_RATIO; packed++) { - block_prefix += block_aggregate - << (sizeof(digit_counter_type) * 8 * packed); - } - - return block_prefix; - } - }; - - template - __dpct_inline__ void scan_counters(const Item &item) { - packed_counter_type raking_partial = upsweep(item); - - prefix_callback callback; - packed_counter_type exclusive_partial = exclusive_scan( - item, raking_partial, sycl::ext::oneapi::plus(), - callback); - - exclusive_downsweep(item, exclusive_partial); - } - -private: - static constexpr int PACKING_RATIO = - sizeof(packed_counter_type) / sizeof(digit_counter_type); - static constexpr int LOG_PACKING_RATIO = log2::VALUE; - static constexpr int LOG_COUNTER_LANES = RADIX_BITS - LOG_PACKING_RATIO; - static constexpr int COUNTER_LANES = 1 << LOG_COUNTER_LANES; - static constexpr int PADDED_COUNTER_LANES = COUNTER_LANES + 1; - - packed_counter_type cached_segment[PADDED_COUNTER_LANES]; - uint8_t *_local_memory; -}; - -template struct base_traits { - - static __dpct_inline__ U twiddle_in(U key) { - throw std::runtime_error("Not implemented"); - } - static __dpct_inline__ U twiddle_out(U key) { - throw std::runtime_error("Not implemented"); - } -}; - -template struct base_traits { - static __dpct_inline__ U twiddle_in(U key) { return key; } - static __dpct_inline__ U twiddle_out(U key) { return key; } -}; - -template struct base_traits { - static constexpr U HIGH_BIT = U(1) << ((sizeof(U) * 8) - 1); - static __dpct_inline__ U twiddle_in(U key) { return key ^ HIGH_BIT; } - static __dpct_inline__ U twiddle_out(U key) { return key ^ HIGH_BIT; } -}; - -template struct base_traits { - static constexpr U HIGH_BIT = U(1) << ((sizeof(U) * 8) - 1); - static __dpct_inline__ U twiddle_in(U key) { - U mask = (key & HIGH_BIT) ? U(-1) : HIGH_BIT; - return key ^ mask; - } - static __dpct_inline__ U twiddle_out(U key) { - U mask = (key & HIGH_BIT) ? HIGH_BIT : U(-1); - return key ^ mask; - } -}; - -template struct traits : base_traits {}; -template <> struct traits : base_traits {}; -template <> struct traits : base_traits {}; -template <> struct traits : base_traits {}; - -} // namespace detail - -namespace detail { - -template struct power_of_two { - enum { VALUE = ((N & (N - 1)) == 0) }; -}; - -__dpct_inline__ uint32_t shr_add(uint32_t x, uint32_t shift, uint32_t addend) { - return (x >> shift) + addend; -} - -} // namespace detail - -/// Implements scatter to blocked exchange pattern used in radix sort algorithm. -/// -/// \tparam T type of the data elements exchanges -/// \tparam VALUES_PER_THREAD number of data elements assigned to a thread -template class exchange { -public: - static size_t get_local_memory_size(size_t group_threads) { - size_t padding_values = - (INSERT_PADDING) - ? ((group_threads * VALUES_PER_THREAD) >> LOG_LOCAL_MEMORY_BANKS) - : 0; - return (group_threads * VALUES_PER_THREAD + padding_values) * sizeof(T); - } - - exchange(uint8_t *local_memory) : _local_memory(local_memory) {} - - /// Rearrange elements from rank order to blocked order - template - __dpct_inline__ void - scatter_to_blocked(Item item, T (&keys)[VALUES_PER_THREAD], - int (&ranks)[VALUES_PER_THREAD]) { - T *buffer = reinterpret_cast(_local_memory); - -#pragma unroll - for (int i = 0; i < VALUES_PER_THREAD; i++) { - int offset = ranks[i]; - if (INSERT_PADDING) - offset = detail::shr_add(offset, LOG_LOCAL_MEMORY_BANKS, offset); - buffer[offset] = keys[i]; - } - - item.barrier(sycl::access::fence_space::local_space); - -#pragma unroll - for (int i = 0; i < VALUES_PER_THREAD; i++) { - int offset = (item.get_local_id(0) * VALUES_PER_THREAD) + i; - if (INSERT_PADDING) - offset = detail::shr_add(offset, LOG_LOCAL_MEMORY_BANKS, offset); - keys[i] = buffer[offset]; - } - } - -private: - static constexpr int LOG_LOCAL_MEMORY_BANKS = 5; - static constexpr bool INSERT_PADDING = - (VALUES_PER_THREAD > 4) && - (detail::power_of_two::VALUE); - - uint8_t *_local_memory; -}; - -/// Implements radix sort to sort integer data elements assigned to all threads -/// in the group. -/// -/// \tparam T type of the data elements exchanges -/// \tparam VALUES_PER_THREAD number of data elements assigned to a thread -/// \tparam DECENDING boolean value indicating if data elements are sorted in -/// decending order. -template -class radix_sort { -public: - static size_t get_local_memory_size(size_t group_threads) { - size_t ranks_size = - detail::radix_rank::get_local_memory_size(group_threads); - size_t exchange_size = - exchange::get_local_memory_size(group_threads); - return sycl::max(ranks_size, exchange_size); - } - - radix_sort(uint8_t *local_memory) : _local_memory(local_memory) {} - - template - __dpct_inline__ void - sort(const Item &item, T (&keys)[VALUES_PER_THREAD], int begin_bit = 0, - int end_bit = 8 * sizeof(T)) { - - uint32_t(&unsigned_keys)[VALUES_PER_THREAD] = - reinterpret_cast(keys); - -#pragma unroll - for (int i = 0; i < VALUES_PER_THREAD; ++i) { - unsigned_keys[i] = detail::traits::twiddle_in(unsigned_keys[i]); - } - - while (true) { - int pass_bits = sycl::min(RADIX_BITS, end_bit - begin_bit); - - int ranks[VALUES_PER_THREAD]; - detail::radix_rank(_local_memory) - .template rank_keys(item, unsigned_keys, ranks, begin_bit, pass_bits); - begin_bit += RADIX_BITS; - - item.barrier(sycl::access::fence_space::local_space); - - exchange(_local_memory) - .scatter_to_blocked(item, keys, ranks); - - item.barrier(sycl::access::fence_space::local_space); - - if (begin_bit >= end_bit) - break; - } - -#pragma unroll - for (int i = 0; i < VALUES_PER_THREAD; ++i) { - unsigned_keys[i] = detail::traits::twiddle_out(unsigned_keys[i]); - } - } - -private: - static constexpr int RADIX_BITS = 4; - - uint8_t *_local_memory; -}; - -/// Perform a reduction of the data elements assigned to all threads in the -/// group. -/// -/// \param item A work-item in a group. -/// \param inputs Pointer to the input data for the reduce operation. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \returns value of the reduction using binary_op -template -__dpct_inline__ T -reduce(Item item, T (&inputs)[VALUES_PER_THREAD], BinaryOperation binary_op) { - T result = inputs[0]; - -#pragma unroll - for (int i = 1; i < VALUES_PER_THREAD; i++) { - result = binary_op(result, inputs[i]); - } - return detail::__reduce_over_group(item.get_group(), result, binary_op); -} - -/// Perform a reduction on a limited number of the work items in a subgroup -/// -/// \param item A work-item in a group. -/// \param value value per work item which is to be reduced -/// \param items_to_reduce num work items at the start of the subgroup to reduce -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \returns value of the reduction using binary_op -template -__dpct_inline__ -typename ::std::enable_if_t, T> -reduce_over_partial_group(const Item &item, const T &value, - const ::std::uint16_t &items_to_reduce, - BinaryOperation binary_op) { - T value_temp = (item.get_local_linear_id() < items_to_reduce) - ? value - : sycl::known_identity_v; - return detail::__reduce_over_group(item.get_sub_group(), value_temp, - binary_op); -} - -/// Perform an inclusive scan over the values of inputs from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param inputs Pointer to the input data for the scan operation. -/// \param outputs Pointer to the location where scan results will be stored. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \returns inclusive scan of the input elements assigned to -/// work-items in the group. -template -__dpct_inline__ void -inclusive_scan(const Item &item, T (&inputs)[VALUES_PER_THREAD], - T (&outputs)[VALUES_PER_THREAD], BinaryOperation binary_op) { - T result = inputs[0]; - -#pragma unroll - for (int i = 1; i < VALUES_PER_THREAD; ++i) { - result = binary_op(result, inputs[i]); - } - - T exclusive_result = - detail::__exclusive_scan_over_group(item.get_group(), result, binary_op); - - if (item.get_local_linear_id() == 0) { - outputs[0] = inputs[0]; - } else { - outputs[0] = binary_op(inputs[0], exclusive_result); - } - -#pragma unroll - for (int i = 1; i < VALUES_PER_THREAD; ++i) { - outputs[i] = binary_op(inputs[i], outputs[i - 1]); - } -} - -/// Perform an inclusive scan over the values of inputs from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param input Pointer to the input data for the scan operation. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \param group_aggregate group-wide aggregate of all inputs -/// in the work-items of the group. \returns inclusive scan of the input -/// elements assigned to work-items in the group. -template -__dpct_inline__ T inclusive_scan(const Item &item, T input, - BinaryOperation binary_op, - T &group_aggregate) { - T output = - detail::__inclusive_scan_over_group(item.get_group(), input, binary_op); - if (item.get_local_linear_id() == item.get_local_range().size() - 1) { - group_aggregate = output; - } - - group_aggregate = detail::__group_broadcast( - item.get_group(), group_aggregate, item.get_local_range().size() - 1); - return output; -} - -/// Perform an inclusive scan over the values of input from all work-items in -/// the group using the operator binary_op, which must be one of the SYCL 2020 -/// group algorithms library function objects. -/// -/// \param item A work-item in a group. -/// \param input Input data for the scan operation. -/// \param binary_op functor that implements the binary operation used to -/// perform the scan. \param prefix_callback_op functor invoked by the first -/// work-item in the group that returns the -/// initial value in the resulting scan of the work-items in the group. -/// \returns inclusive scan of the input elements assigned to work-items in the -/// group. -template -__dpct_inline__ T -inclusive_scan(const Item &item, T input, BinaryOperation binary_op, - GroupPrefixCallbackOperation &prefix_callback_op) { - T group_aggregate; - - T output = inclusive_scan(item, input, binary_op, group_aggregate); - T group_prefix = prefix_callback_op(group_aggregate); - - return binary_op(group_prefix, output); -} - -} // namespace group - -namespace device { - -namespace detail { - -template constexpr auto __joint_reduce(_Args... __args) { - return sycl::joint_reduce(__args...); -} - -} // namespace detail - -/// Perform a reduce on each of the segments specified within data stored on -/// the device. -/// -/// \param queue Command queue used to access device used for reduction -/// \param inputs Pointer to the data elements on the device to be reduced -/// \param outputs Pointer to the storage where the reduced value for each -/// segment will be stored \param segment_count number of segments to be reduced -/// \param begin_offsets Pointer to the set of indices that are the first -/// element in each segment \param end_offsets Pointer to the set of indices -/// that are one past the last element in each segment \param binary_op functor -/// that implements the binary operation used to perform the scan. \param init -/// initial value of the reduction for each segment. -template -void segmented_reduce(sycl::queue queue, T *inputs, T *outputs, - size_t segment_count, OffsetT *begin_offsets, - OffsetT *end_offsets, BinaryOperation binary_op, T init) { - - sycl::range<1> global_size(segment_count * GROUP_SIZE); - sycl::range<1> local_size(GROUP_SIZE); - - queue.submit([&](sycl::handler &cgh) { - cgh.parallel_for( - sycl::nd_range<1>(global_size, local_size), [=](sycl::nd_item<1> item) { - OffsetT segment_begin = begin_offsets[item.get_group_linear_id()]; - OffsetT segment_end = end_offsets[item.get_group_linear_id()]; - if (segment_begin == segment_end) { - if (item.get_local_linear_id() == 0) { - outputs[item.get_group_linear_id()] = init; - } - return; - } - - sycl::multi_ptr - input_ptr = inputs; - T group_aggregate = detail::__joint_reduce( - item.get_group(), input_ptr + segment_begin, - input_ptr + segment_end, init, binary_op); - - if (item.get_local_linear_id() == 0) { - outputs[item.get_group_linear_id()] = group_aggregate; - } - }); - }); -} - - -#ifdef SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS - -namespace experimental { -namespace detail { -template struct __is_any { - constexpr static bool value = std::disjunction_v< - std::is_same, std::remove_cv_t<_Ts>>...>; -}; - -template struct __in_native_op_list { - constexpr static bool value = - __is_any<_Bp, sycl::plus<_Tp>, sycl::bit_or<_Tp>, sycl::bit_xor<_Tp>, - sycl::bit_and<_Tp>, sycl::maximum<_Tp>, sycl::minimum<_Tp>, - sycl::multiplies<_Tp>>::value; -}; - -template struct __is_native_op { - constexpr static bool value = __in_native_op_list<_Tp, _Bp>::value || - __in_native_op_list::value; -}; - -} // namespace detail - -/// Perform a reduce on each of the segments specified within data stored on -/// the device. Compared with dpct::device::segmented_reduce, this experimental -/// feature support user define reductions. -/// -/// \param queue Command queue used to access device used for reduction -/// \param inputs Pointer to the data elements on the device to be reduced -/// \param outputs Pointer to the storage where the reduced value for each -/// segment will be stored \param segment_count number of segments to be reduced -/// \param begin_offsets Pointer to the set of indices that are the first -/// element in each segment \param end_offsets Pointer to the set of indices -/// that are one past the last element in each segment \param binary_op functor -/// that implements the binary operation used to perform the scan. \param init -/// initial value of the reduction for each segment. -template -void segmented_reduce(sycl::queue queue, T *inputs, T *outputs, - size_t segment_count, OffsetT *begin_offsets, - OffsetT *end_offsets, BinaryOperation binary_op, T init) { - - sycl::range<1> global_size(segment_count * GROUP_SIZE); - sycl::range<1> local_size(GROUP_SIZE); - - if constexpr (!detail::__is_native_op::value) { - queue.submit([&](sycl::handler &cgh) { - size_t temp_memory_size = GROUP_SIZE * sizeof(T); - auto scratch = sycl::local_accessor(temp_memory_size, cgh); - cgh.parallel_for( - sycl::nd_range<1>(global_size, local_size), - [=](sycl::nd_item<1> item) { - OffsetT segment_begin = begin_offsets[item.get_group_linear_id()]; - OffsetT segment_end = end_offsets[item.get_group_linear_id()]; - if (segment_begin == segment_end) { - if (item.get_local_linear_id() == 0) { - outputs[item.get_group_linear_id()] = init; - } - return; - } - // Create a handle that associates the group with an allocation it - // can use - auto handle = - sycl::ext::oneapi::experimental::group_with_scratchpad( - item.get_group(), - sycl::span(&scratch[0], temp_memory_size)); - T group_aggregate = sycl::ext::oneapi::experimental::joint_reduce( - handle, inputs + segment_begin, inputs + segment_end, init, - binary_op); - if (item.get_local_linear_id() == 0) { - outputs[item.get_group_linear_id()] = group_aggregate; - } - }); - }); - } else { - dpct::device::segmented_reduce(queue, inputs, outputs, - segment_count, begin_offsets, - end_offsets, binary_op, init); - } -} -} // namespace experimental - -#endif // SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS - - -} // namespace device -} // namespace dpct - -#endif diff --git a/dpct/dpl_extras/functional.h b/dpct/dpl_extras/functional.h deleted file mode 100644 index bab82814c..000000000 --- a/dpct/dpl_extras/functional.h +++ /dev/null @@ -1,453 +0,0 @@ -//==---- functional.h -----------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_FUNCTIONAL_H__ -#define __DPCT_FUNCTIONAL_H__ - -#include -#include -#include - -#if ONEDPL_USE_DPCPP_BACKEND -#include -#endif - -#include -#include - -#include "../dpct.hpp" -#define _DPCT_GCC_VERSION \ - (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) - -// Portability "#pragma" definition -#ifdef _MSC_VER -#define _DPCT_PRAGMA(x) __pragma(x) -#else -#define _DPCT_PRAGMA(x) _Pragma(#x) -#endif - -// Enable loop unrolling pragmas where supported -#if (__INTEL_COMPILER || \ - (!defined(__INTEL_COMPILER) && _DPCT_GCC_VERSION >= 80000)) -#define _DPCT_PRAGMA_UNROLL _DPCT_PRAGMA(unroll) -#else // no pragma unroll -#define _DPCT_PRAGMA_UNROLL -#endif - -namespace dpct { - -struct null_type {}; - -// Function object to wrap user defined functors to provide compile time "const" -// workaround for user function objects. -// The SYCL spec (4.12) states that writing to a function object during a SYCL -// kernel is undefined behavior. This wrapper is provided as a compile-time -// work around, but functors used in SYCL kernels must be `const` in practice. -template struct mark_functor_const { - mutable _Op op; - mark_functor_const() : op() {} - mark_functor_const(const _Op &__op) : op(__op) {} - mark_functor_const(_Op &&__op) : op(::std::move(__op)) {} - template auto operator()(_T &&...x) const { - return op(std::forward<_T>(x)...); - } -}; - -namespace internal { - -template -using enable_if_execution_policy = - typename std::enable_if::type>::value, - _T>::type; - -template struct is_hetero_execution_policy : ::std::false_type {}; - -template -struct is_hetero_execution_policy< - oneapi::dpl::execution::device_policy> : ::std::true_type { -}; - -template struct is_fpga_execution_policy : ::std::false_type {}; - -#if _ONEDPL_FPGA_DEVICE -template -struct is_hetero_execution_policy< - execution::fpga_policy> : ::std::true_type { -}; -#endif - -template -using enable_if_hetero_execution_policy = typename std::enable_if< - is_hetero_execution_policy::type>::value, - _T>::type; - -#if _ONEDPL_CPP14_INTEGER_SEQUENCE_PRESENT - -template -using index_sequence = std::index_sequence<_Sp...>; -template -using make_index_sequence = std::make_index_sequence<_Np>; - -#else - -template class index_sequence {}; - -template -struct make_index_sequence_impl - : make_index_sequence_impl<_Np - 1, _Np - 1, _Sp...> {}; - -template struct make_index_sequence_impl<0, _Sp...> { - using type = index_sequence<_Sp...>; -}; - -template -using make_index_sequence = typename make_index_sequence_impl<_Np>::type; -#endif - -// Minimal buffer implementations for temporary storage in mapping rules -// Some of our algorithms need to start with raw memory buffer, -// not an initialized array, because initialization/destruction -// would make the span be at least O(N). -#if ONEDPL_USE_DPCPP_BACKEND -template class __buffer { - sycl::buffer<_Tp, 1> __buf; - - __buffer(const __buffer &) = delete; - - void operator=(const __buffer &) = delete; - -public: - // Try to obtain buffer of given size to store objects of _Tp type - __buffer(std::size_t __n) : __buf(sycl::range<1>(__n)) {} - - // Return pointer to buffer, or NULL if buffer could not be obtained. - auto get() -> decltype(oneapi::dpl::begin(__buf)) const { - return oneapi::dpl::begin(__buf); - } -}; -#else -template class __buffer { - std::unique_ptr<_Tp> _M_ptr; - - __buffer(const __buffer &) = delete; - - void operator=(const __buffer &) = delete; - -public: - // Try to obtain buffer of given size to store objects of _Tp type - __buffer(const std::size_t __n) : _M_ptr(new _Tp[__n]) {} - - // Return pointer to buffer, or NULL if buffer could not be obtained. - _Tp *get() const { return _M_ptr.get(); } -}; -#endif - -// Implements C++14 std::less specialization to allow parameter type -// deduction. -class __less { -public: - template - bool operator()(_Xp &&__x, _Yp &&__y) const { - return std::forward<_Xp>(__x) < std::forward<_Yp>(__y); - } -}; - -template struct rebind_policy { - using type = Policy; -}; - -template -struct rebind_policy, - NewName> { - using type = oneapi::dpl::execution::device_policy; -}; - -#if _ONEDPL_FPGA_DEVICE -template -struct rebind_policy, - NewName> { - using type = oneapi::dpl::execution::fpga_policy; -}; -#endif - -template ::reference, - typename R2 = typename std::iterator_traits::reference> -struct perm_fun { - typedef R2 result_of; - perm_fun(T1 input) : source(input) {} - - R2 operator()(R1 x) const { return *(source + x); } - -private: - T1 source; -}; - -// Functor compares first element (key) from tied sequence. -template struct compare_key_fun { - typedef bool result_of; - compare_key_fun(Compare _comp = internal::__less()) : comp(_comp) {} - - template - result_of operator()(_T1 &&a, _T2 &&b) const { - using std::get; - return comp(get<0>(a), get<0>(b)); - } - -private: - mutable Compare comp; -}; - -// Functor evaluates second element of tied sequence with predicate. -// Used by: copy_if, remove_copy_if, stable_partition_copy -// Lambda: -template struct predicate_key_fun { - typedef bool result_of; - predicate_key_fun(Predicate _pred) : pred(_pred) {} - - template result_of operator()(_T1 &&a) const { - using std::get; - return pred(get<1>(a)); - } - -private: - mutable Predicate pred; -}; - -// Used by: remove_if -template struct negate_predicate_key_fun { - typedef bool result_of; - negate_predicate_key_fun(Predicate _pred) : pred(_pred) {} - - template result_of operator()(_T1 &&a) const { - using std::get; - return !pred(get<1>(a)); - } - -private: - mutable Predicate pred; -}; - -template struct sequence_fun { - using result_type = T; - sequence_fun(T _init, T _step) : init(_init), step(_step) {} - - template result_type operator()(_T &&i) const { - return static_cast(init + step * i); - } - -private: - const T init; - const T step; -}; - -//[binary_pred](Ref a, Ref b){ return(binary_pred(get<0>(a),get<0>(b))); -template struct unique_fun { - typedef bool result_of; - unique_fun(Predicate _pred) : pred(_pred) {} - template result_of operator()(_T &&a, _T &&b) const { - using std::get; - return pred(get<0>(a), get<0>(b)); - } - -private: - mutable Predicate pred; -}; - -// Lambda: [pred, &new_value](Ref1 a, Ref2 s) {return pred(s) ? new_value : a; -// }); -template struct replace_if_fun { -public: - typedef T result_of; - replace_if_fun(Predicate _pred, T _new_value) - : pred(_pred), new_value(_new_value) {} - - template T operator()(_T1 &&a, _T2 &&s) const { - return pred(s) ? new_value : a; - } - -private: - mutable Predicate pred; - const T new_value; -}; - -//[pred,op](Ref a){return pred(a) ? op(a) : a; } -template -struct transform_if_fun { - transform_if_fun(Predicate _pred, Operator _op) : pred(_pred), op(_op) {} - template - void operator()(_T&& t) const { - using std::get; - if (pred(get<0>(t))) - get<1>(t) = op(get<0>(t)); - } - -private: - mutable Predicate pred; - mutable Operator op; -}; - -//[pred, op](Ref1 a, Ref2 s) { return pred(s) ? op(a) : a; }); -template -struct transform_if_unary_zip_mask_fun { - transform_if_unary_zip_mask_fun(Predicate _pred, Operator _op) : pred(_pred), op(_op) {} - template - void operator()(_T&& t) const { - using std::get; - if (pred(get<1>(t))) - get<2>(t) = op(get<0>(t)); - } - -private: - mutable Predicate pred; - mutable Operator op; -}; - -template -class transform_if_zip_mask_fun { -public: - transform_if_zip_mask_fun(Predicate _pred = oneapi::dpl::identity(), - BinaryOperation _op = oneapi::dpl::identity()) - : pred(_pred), op(_op) {} - template void operator()(_T &&t) const { - using std::get; - if (pred(get<2>(t))) - get<3>(t) = op(get<0>(t), get<1>(t)); - } - -private: - mutable Predicate pred; - mutable BinaryOperation op; -}; - -// This following code is similar to a section of code in -// oneDPL/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h -// It has a similar approach, and could be consolidated. -// Outside of some differences in approach, there are two significant -// differences in function. -// -// 1) This code allows the output type of the bit range translation to be fit -// into to the minimal type required to provide that many bits. The code in -// oneDPL to calculate the bucket for the radix is similar but its output is -// always std::uint32_t. The assumption that the bit range desired will fit in -// 32 bits is not true for this code. -// -// 2) This code ensures that for floating point type, -0.0f and 0.0f map to the -// same value. This allows the output of this translation to be used to provide -// a sort which ensures the stability of these values for floating point types. - -template struct uint_byte_map {}; -template <> struct uint_byte_map<1> { using type = uint8_t; }; -template <> struct uint_byte_map<2> { using type = uint16_t; }; -template <> struct uint_byte_map<4> { using type = uint32_t; }; -template <> struct uint_byte_map<8> { using type = uint64_t; }; - -template struct uint_map { - using type = typename uint_byte_map::type; -}; - -template class translate_key { - using uint_type_t = typename uint_map::type; - -public: - translate_key(int begin_bit, int end_bit) { - shift = begin_bit; - mask = ~OutKeyT(0); // all ones - mask = mask >> (sizeof(OutKeyT) * 8 - - (end_bit - begin_bit)); // setup appropriate mask - flip_sign = uint_type_t(1) << (sizeof(uint_type_t) * 8 - 1); // sign bit - flip_key = ~uint_type_t(0); // 0xF...F - } - - inline OutKeyT operator()(const T &key) const { - uint_type_t intermediate; - if constexpr (std::is_floating_point::value) { - // normal case (both -0.0f and 0.0f equal -0.0f) - if (key != T(-0.0f)) { - uint_type_t is_negative = reinterpret_cast(key) >> - (sizeof(uint_type_t) * 8 - 1); - intermediate = reinterpret_cast(key) ^ - ((is_negative * flip_key) | flip_sign); - } else // special case for -0.0f to keep stability with 0.0f - { - T negzero = T(-0.0f); - intermediate = reinterpret_cast(negzero); - } - } else if constexpr (std::is_signed::value) { - intermediate = reinterpret_cast(key) ^ flip_sign; - } else { - intermediate = key; - } - - return static_cast(intermediate >> shift) & - mask; // shift, cast, and mask - } - -private: - uint8_t shift; - OutKeyT mask; - uint_type_t flip_sign; - uint_type_t flip_key; -}; - -// Unary operator that returns reference to its argument. Ported from -// oneDPL: oneapi/dpl/pstl/utils.h -struct no_op_fun { - template Tp &&operator()(Tp &&a) const { - return ::std::forward(a); - } -}; - -// Unary functor which composes a pair of functors by calling them in succession -// on an input -template -struct __composition_functor { - __composition_functor(FunctorInner in, FunctorOuter out) - : _in(in), _out(out) {} - template T operator()(const T &i) const { - return _out(_in(i)); - } - FunctorInner _in; - FunctorOuter _out; -}; - -// Unary functor which maps an index of a ROI into a 2D flattened array -template struct __roi_2d_index_functor { - __roi_2d_index_functor(const OffsetT &num_cols, - const ::std::size_t &row_stride) - : _num_cols(num_cols), _row_stride(row_stride) {} - - template Index operator()(const Index &i) const { - return _row_stride * (i / _num_cols) + (i % _num_cols); - } - - OffsetT _num_cols; - ::std::size_t _row_stride; -}; - -// Unary functor which maps and index into an interleaved array by its active -// channel -template struct __interleaved_index_functor { - __interleaved_index_functor(const OffsetT &total_channels, - const OffsetT &active_channel) - : _total_channels(total_channels), _active_channel(active_channel) {} - - template Index operator()(const Index &i) const { - return i * _total_channels + _active_channel; - } - - OffsetT _total_channels; - OffsetT _active_channel; -}; - -} // end namespace internal - -} // end namespace dpct - -#endif diff --git a/dpct/dpl_extras/iterators.h b/dpct/dpl_extras/iterators.h deleted file mode 100644 index 2e1d10986..000000000 --- a/dpct/dpl_extras/iterators.h +++ /dev/null @@ -1,347 +0,0 @@ -//==---- iterators.h ------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_ITERATORS_H__ -#define __DPCT_ITERATORS_H__ - -#include - -#include "functional.h" - -namespace dpct { - -namespace internal { - -// Wrapper class returned from a dereferenced transform_iterator which was -// created using -// make_transform_output_iterator(). Used to apply the supplied transform -// function when writing into an object of this class. -// -// Example: -// int a[] = {0, 1, 2, 3, 4}; -// int* p = a; -// auto f = [](auto v) {return v*v;}; -// auto tr_out = dpct::make_transform_output_iterator(p+1, f); -// auto wrap = *tr_out; // wrap is a transform_output_ref_wrapper -// std::cout<<*(p+1)< class transform_output_ref_wrapper { -private: - T __my_reference_; - _UnaryFunc __my_unary_func_; - -public: - template - transform_output_ref_wrapper(U &&__reference, _UnaryFunc __unary_func) - : __my_reference_(std::forward(__reference)), - __my_unary_func_(__unary_func) {} - - // When writing to an object of this type, apply the supplied unary function, - // then write to the wrapped reference - template - transform_output_ref_wrapper &operator=(const UnaryInputType &e) { - __my_reference_ = __my_unary_func_(e); - return *this; - } -}; - -// Unary functor to create a transform_output_reference_wrapper when a -// transform_iterator is dereferenced, so that a -// the supplied unary function may be applied on write, resulting in a -// transform_output_iterator -template struct _Unary_Out { - _Unary_Out(_UnaryFunc __f_) : __f(__f_) {} - _UnaryFunc __f; - template auto operator()(T &&val) const { - return transform_output_ref_wrapper(std::forward(val), - __f); - } -}; - -} // end namespace internal - -using std::advance; - -using std::distance; - -template -oneapi::dpl::counting_iterator make_counting_iterator(const T &input) { - return oneapi::dpl::counting_iterator(input); -} - -template class constant_iterator { -public: - typedef std::false_type is_hetero; - typedef std::true_type is_passed_directly; - typedef std::ptrdiff_t difference_type; - typedef _Tp value_type; - typedef _Tp *pointer; - // There is no storage behind the iterator, so we return a value instead of - // reference. - typedef const _Tp reference; - typedef const _Tp const_reference; - typedef std::random_access_iterator_tag iterator_category; - - explicit constant_iterator(_Tp __init) - : __my_value_(__init), __my_counter_(0) {} - -private: - // used to construct iterator instances with different counter values required - // by arithmetic operators - constant_iterator(const _Tp &__value, const difference_type &__offset) - : __my_value_(__value), __my_counter_(__offset) {} - -public: - // non-const variants of access operators are not provided so unintended - // writes are caught at compile time. - const_reference operator*() const { return __my_value_; } - const_reference operator[](difference_type) const { return __my_value_; } - - difference_type operator-(const constant_iterator &__it) const { - return __my_counter_ - __it.__my_counter_; - } - - constant_iterator &operator+=(difference_type __forward) { - __my_counter_ += __forward; - return *this; - } - constant_iterator &operator-=(difference_type __backward) { - return *this += -__backward; - } - constant_iterator &operator++() { return *this += 1; } - constant_iterator &operator--() { return *this -= 1; } - - constant_iterator operator++(int) { - constant_iterator __it(*this); - ++(*this); - return __it; - } - constant_iterator operator--(int) { - constant_iterator __it(*this); - --(*this); - return __it; - } - - constant_iterator operator-(difference_type __backward) const { - return constant_iterator(__my_value_, __my_counter_ - __backward); - } - constant_iterator operator+(difference_type __forward) const { - return constant_iterator(__my_value_, __my_counter_ + __forward); - } - friend constant_iterator operator+(difference_type __forward, - const constant_iterator __it) { - return __it + __forward; - } - - bool operator==(const constant_iterator &__it) const { - return __my_value_ == __it.__my_value_ && - this->__my_counter_ == __it.__my_counter_; - } - bool operator!=(const constant_iterator &__it) const { - return !(*this == __it); - } - bool operator<(const constant_iterator &__it) const { - return *this - __it < 0; - } - bool operator>(const constant_iterator &__it) const { return __it < *this; } - bool operator<=(const constant_iterator &__it) const { - return !(*this > __it); - } - bool operator>=(const constant_iterator &__it) const { - return !(*this < __it); - } - -private: - _Tp __my_value_; - uint64_t __my_counter_; -}; - -template -constant_iterator<_Tp> make_constant_iterator(_Tp __value) { - return constant_iterator<_Tp>(__value); -} - -// key_value_pair class to represent a key and value, specifically a -// dereferenced arg_index_input_iterator -template class key_value_pair { -public: - key_value_pair() = default; - - key_value_pair(const _KeyTp &_key, const _ValueTp &_value) - : key(_key), value(_value) {} - - bool operator==(const key_value_pair<_KeyTp, _ValueTp> &_kvp) const { - return (key == _kvp.key) && (value == _kvp.value); - } - - bool operator!=(const key_value_pair<_KeyTp, _ValueTp> &_kvp) const { - return (key != _kvp.key) || (value != _kvp.value); - } - - _KeyTp key; - _ValueTp value; -}; - -namespace detail { - -template struct make_key_value_pair { - template - key_value_pair - operator()(const oneapi::dpl::__internal::tuple &tup) const { - return ::dpct::key_value_pair(::std::get<0>(tup), - ::std::get<1>(tup)); - } -}; - -template struct __zip_iterator_impl; -template struct __zip_iterator_impl> { - using type = oneapi::dpl::zip_iterator; -}; - -} // end namespace detail - -// dpct::zip_iterator can only accept std::tuple type as template argument for -// compatibility purpose. Please use oneapi::dpl::zip_iterator if you want to -// pass iterator's types directly. -template -using zip_iterator = typename detail::__zip_iterator_impl::type; - -// arg_index_input_iterator is an iterator over a input iterator, with a index. -// When dereferenced, it returns a key_value_pair, which can be interrogated for -// the index key or the value from the input iterator -template ::value_type> -class arg_index_input_iterator - : public oneapi::dpl::transform_iterator< - oneapi::dpl::zip_iterator, - InputIteratorT>, - detail::make_key_value_pair> { - using arg_index_input_iterator_wrap = oneapi::dpl::transform_iterator< - oneapi::dpl::zip_iterator, - InputIteratorT>, - detail::make_key_value_pair>; - -public: - typedef OffsetT difference_type; - - // signal to __get_sycl_range that this iterator is as a direct pass iterator - using is_zip = ::std::true_type; - - arg_index_input_iterator(const arg_index_input_iterator_wrap &__arg_wrap) - : arg_index_input_iterator_wrap(__arg_wrap) {} - arg_index_input_iterator(InputIteratorT __iter) - : arg_index_input_iterator_wrap( - oneapi::dpl::make_zip_iterator( - oneapi::dpl::counting_iterator(OffsetT(0)), __iter), - detail::make_key_value_pair()) {} - - arg_index_input_iterator &operator=(const arg_index_input_iterator &__input) { - arg_index_input_iterator_wrap::operator=(__input); - return *this; - } - arg_index_input_iterator &operator++() { - arg_index_input_iterator_wrap::operator++(); - return *this; - } - arg_index_input_iterator &operator--() { - arg_index_input_iterator_wrap::operator--(); - return *this; - } - arg_index_input_iterator operator++(int) { - arg_index_input_iterator __it(*this); - ++(*this); - return __it; - } - arg_index_input_iterator operator--(int) { - arg_index_input_iterator __it(*this); - --(*this); - return __it; - } - arg_index_input_iterator operator+(difference_type __forward) const { - return arg_index_input_iterator( - arg_index_input_iterator_wrap::operator+(__forward)); - } - arg_index_input_iterator operator-(difference_type __backward) const { - return arg_index_input_iterator( - arg_index_input_iterator_wrap::operator-(__backward)); - } - arg_index_input_iterator &operator+=(difference_type __forward) { - arg_index_input_iterator_wrap::operator+=(__forward); - return *this; - } - arg_index_input_iterator &operator-=(difference_type __backward) { - arg_index_input_iterator_wrap::operator-=(__backward); - return *this; - } - - friend arg_index_input_iterator - operator+(difference_type __forward, const arg_index_input_iterator &__it) { - return __it + __forward; - } - - difference_type operator-(const arg_index_input_iterator &__it) const { - return arg_index_input_iterator_wrap::operator-(__it); - } - bool operator==(const arg_index_input_iterator &__it) const { - return arg_index_input_iterator_wrap::operator==(__it); - } - bool operator!=(const arg_index_input_iterator &__it) const { - return !(*this == __it); - } - bool operator<(const arg_index_input_iterator &__it) const { - return *this - __it < 0; - } - bool operator>(const arg_index_input_iterator &__it) const { - return __it < *this; - } - bool operator<=(const arg_index_input_iterator &__it) const { - return !(*this > __it); - } - bool operator>=(const arg_index_input_iterator &__it) const { - return !(*this < __it); - } - - // returns an arg_index_input_iterator with the same iter position, but a - // count reset to 0 - arg_index_input_iterator create_normalized() { - return arg_index_input_iterator( - ::std::get<1>(arg_index_input_iterator_wrap::base().base())); - } -}; - -template struct io_iterator_pair { - inline io_iterator_pair() : selector(false) {} - - inline io_iterator_pair(const IterT &first, const IterT &second) - : selector(false) { - iter[0] = first; - iter[1] = second; - } - - inline IterT first() const { return selector ? iter[1] : iter[0]; } - - inline IterT second() const { return selector ? iter[0] : iter[1]; } - - inline void swap() { selector = !selector; } - - bool selector; - - IterT iter[2]; -}; - -template -auto make_transform_output_iterator(_Iter __it, _UnaryFunc __unary_func) { - return oneapi::dpl::transform_iterator( - __it, internal::_Unary_Out<_UnaryFunc>(__unary_func)); -} - -} // end namespace dpct - -#endif diff --git a/dpct/dpl_extras/memory.h b/dpct/dpl_extras/memory.h deleted file mode 100644 index 08b965133..000000000 --- a/dpct/dpl_extras/memory.h +++ /dev/null @@ -1,1024 +0,0 @@ -//==---- memory.h ---------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_MEMORY_H__ -#define __DPCT_MEMORY_H__ - -#include -#include -#include "functional.h" - -// Memory management section: -// device_pointer, device_reference, swap, device_iterator, malloc_device, -// device_new, free_device, device_delete -namespace dpct { -namespace detail { -template -struct make_allocatable -{ - using type = T; -}; - -template <> -struct make_allocatable -{ - using type = dpct::byte_t; -}; - -#if defined(__LIBSYCL_MAJOR_VERSION) && defined(__LIBSYCL_MINOR_VERSION) && \ - defined(__LIBSYCL_PATCH_VERSION) -#define _DPCT_LIBSYCL_VERSION \ - (__LIBSYCL_MAJOR_VERSION * 10000 + __LIBSYCL_MINOR_VERSION * 100 + \ - __LIBSYCL_PATCH_VERSION) -#else -#define _DPCT_LIBSYCL_VERSION 0 -#endif - -template -using __buffer_allocator = -#if _DPCT_LIBSYCL_VERSION >= 60000 - sycl::buffer_allocator::type>; -#else - sycl::buffer_allocator; -#endif -} // namespace detail - -#ifdef DPCT_USM_LEVEL_NONE -template > -class device_pointer; -#else -template class device_pointer; -#endif - -template struct device_reference { - using pointer = device_pointer; - using value_type = T; - template - device_reference(const device_reference &input) - : value(input.value) {} - device_reference(const pointer &input) : value((*input).value) {} - device_reference(value_type &input) : value(input) {} - template - device_reference &operator=(const device_reference &input) { - value = input; - return *this; - }; - device_reference &operator=(const device_reference &input) { - T val = input.value; - value = val; - return *this; - }; - device_reference &operator=(const value_type &x) { - value = x; - return *this; - }; - pointer operator&() const { return pointer(&value); }; - operator value_type() const { return T(value); } - device_reference &operator++() { - ++value; - return *this; - }; - device_reference &operator--() { - --value; - return *this; - }; - device_reference operator++(int) { - device_reference ref(*this); - ++(*this); - return ref; - }; - device_reference operator--(int) { - device_reference ref(*this); - --(*this); - return ref; - }; - device_reference &operator+=(const T &input) { - value += input; - return *this; - }; - device_reference &operator-=(const T &input) { - value -= input; - return *this; - }; - device_reference &operator*=(const T &input) { - value *= input; - return *this; - }; - device_reference &operator/=(const T &input) { - value /= input; - return *this; - }; - device_reference &operator%=(const T &input) { - value %= input; - return *this; - }; - device_reference &operator&=(const T &input) { - value &= input; - return *this; - }; - device_reference &operator|=(const T &input) { - value |= input; - return *this; - }; - device_reference &operator^=(const T &input) { - value ^= input; - return *this; - }; - device_reference &operator<<=(const T &input) { - value <<= input; - return *this; - }; - device_reference &operator>>=(const T &input) { - value >>= input; - return *this; - }; - void swap(device_reference &input) { - T tmp = (*this); - *this = (input); - input = (tmp); - } - T &value; -}; - -template -void swap(device_reference &x, device_reference &y) { - x.swap(y); -} - -template void swap(T &x, T &y) { - T tmp = x; - x = y; - y = tmp; -} - -template -::std::ostream &operator<<(::std::ostream &out, - const device_reference &ref) { - return out << T(ref); -} - -namespace internal { -// struct for checking if iterator is heterogeneous or not -template // for non-heterogeneous iterators -struct is_hetero_iterator : std::false_type {}; - -template // for heterogeneous iterators -struct is_hetero_iterator< - Iter, typename std::enable_if::type> - : std::true_type {}; -} // namespace internal - -#ifdef DPCT_USM_LEVEL_NONE -// Must be forward declared due to default argument -template -device_pointer device_new(device_pointer, const T &, - const std::size_t = 1); - -template -class device_iterator; - -template -class device_pointer_base { -protected: - sycl::buffer buffer; - std::size_t idx; - - // Declare friend to give access to protected buffer and idx members - template - friend device_pointer device_new(device_pointer, const T &, - const std::size_t); - -public: - using pointer = ValueType *; - using difference_type = std::make_signed::type; - - device_pointer_base(sycl::buffer in, std::size_t i = 0) - : buffer(in), idx(i) {} -#ifdef __USE_DPCT - template - device_pointer_base(OtherT *ptr) - : buffer( - dpct::detail::mem_mgr::instance() - .translate_ptr(ptr) - .buffer.template reinterpret(sycl::range<1>( - dpct::detail::mem_mgr::instance().translate_ptr(ptr).size / - sizeof(ValueType)))), - idx(ptr - (ValueType*)dpct::detail::mem_mgr::instance() - .translate_ptr(ptr).alloc_ptr) {} -#endif - device_pointer_base(const std::size_t count) - : buffer(sycl::range<1>(count / sizeof(ValueType))), idx() {} - // buffer has no default ctor we pass zero-range to create an empty buffer - device_pointer_base() : buffer(sycl::range<1>(0)) {} - device_pointer_base(const device_pointer_base &in) - : buffer(in.buffer), idx(in.idx) {} - pointer get() const { - auto res = - (const_cast(this) - ->buffer.template get_access()) - .get_pointer(); - return res + idx; - } - operator ValueType *() { - auto res = (buffer.template get_access()) - .get_pointer(); - return res + idx; - } - operator ValueType *() const { - auto res = - (const_cast(this) - ->buffer.template get_access()) - .get_pointer(); - return res + idx; - } - Derived operator+(difference_type forward) const { - return Derived{buffer, idx + forward}; - } - Derived operator-(difference_type backward) const { - return Derived{buffer, idx - backward}; - } - Derived operator++(int) { - Derived p(buffer, idx); - idx += 1; - return p; - } - Derived operator--(int) { - Derived p(buffer, idx); - idx -= 1; - return p; - } - difference_type operator-(const Derived &it) const { return idx - it.idx; } - template - typename std::enable_if::value, - difference_type>::type - operator-(const OtherIterator &it) const { - return idx - std::distance(oneapi::dpl::begin(buffer), it); - } - - std::size_t get_idx() const { return idx; } // required - - sycl::buffer get_buffer() { - return buffer; - } // required -}; - -template -class device_pointer - : public device_pointer_base> { -private: - using base_type = - device_pointer_base; - -public: - using value_type = dpct::byte_t; - using difference_type = std::make_signed::type; - using pointer = void *; - using reference = value_type &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::true_type; // required - using is_passed_directly = std::false_type; - static constexpr sycl::access_mode mode = Mode; // required - - device_pointer(sycl::buffer in, std::size_t i = 0) - : base_type(in, i) {} -#ifdef __USE_DPCT - template device_pointer(OtherT *ptr) : base_type(ptr) {} -#endif - // needed for malloc_device, count is number of bytes to allocate - device_pointer(const std::size_t count) : base_type(count) {} - device_pointer() : base_type() {} - device_pointer(const device_pointer &in) : base_type(in) {} - device_pointer &operator+=(difference_type forward) { - this->idx += forward; - return *this; - } - device_pointer &operator-=(difference_type backward) { - this->idx -= backward; - return *this; - } - // include operators from base class - using base_type::operator++; - using base_type::operator--; - device_pointer &operator++() { - this->idx += 1; - return *this; - } - device_pointer &operator--() { - this->idx -= 1; - return *this; - } -}; - -template -class device_pointer - : public device_pointer_base> { -private: - using base_type = device_pointer_base; - -public: - using value_type = T; - using difference_type = std::make_signed::type; - using pointer = T *; - using reference = T &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::true_type; // required - using is_passed_directly = std::false_type; - static constexpr sycl::access_mode mode = Mode; // required - - device_pointer(sycl::buffer in, std::size_t i = 0) : base_type(in, i) {} -#ifdef __USE_DPCT - template device_pointer(OtherT *ptr) : base_type(ptr) {} -#endif - // needed for malloc_device, count is number of bytes to allocate - device_pointer(const std::size_t count) : base_type(count) {} - device_pointer() : base_type() {} - device_pointer(const device_pointer &in) : base_type(in) {} - device_pointer &operator+=(difference_type forward) { - this->idx += forward; - return *this; - } - device_pointer &operator-=(difference_type backward) { - this->idx -= backward; - return *this; - } - operator device_pointer() { - auto converted_buf = (this->buffer) - .template reinterpret(sycl::range<1>( - sizeof(value_type) * this->buffer.size())); - return device_pointer(converted_buf, this->idx); - } - // include operators from base class - using base_type::operator++; - using base_type::operator--; - device_pointer &operator++() { - this->idx += 1; - return *this; - } - device_pointer &operator--() { - this->idx -= 1; - return *this; - } -}; -#else -template class device_iterator; - -template class device_pointer_base { -protected: - ValueType *ptr; - -public: - using pointer = ValueType *; - using difference_type = std::make_signed::type; - - device_pointer_base(ValueType *p) : ptr(p) {} - device_pointer_base(const std::size_t count) { - sycl::queue default_queue = dpct::get_default_queue(); - ptr = static_cast(sycl::malloc_shared( - count, default_queue.get_device(), default_queue.get_context())); - } - device_pointer_base() {} - pointer get() const { return ptr; } - operator ValueType *() { return ptr; } - operator ValueType *() const { return ptr; } - - ValueType &operator[](difference_type idx) { return ptr[idx]; } - ValueType &operator[](difference_type idx) const { return ptr[idx]; } - - Derived operator+(difference_type forward) const { - return Derived{ptr + forward}; - } - Derived operator-(difference_type backward) const { - return Derived{ptr - backward}; - } - Derived operator++(int) { - Derived p(ptr); - ++ptr; - return p; - } - Derived operator--(int) { - Derived p(ptr); - --ptr; - return p; - } - difference_type operator-(const Derived &it) const { return ptr - it.ptr; } -}; - -template <> -class device_pointer - : public device_pointer_base> { -private: - using base_type = device_pointer_base>; - -public: - using value_type = dpct::byte_t; - using difference_type = std::make_signed::type; - using pointer = void *; - using reference = value_type &; - using const_reference = const value_type &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::false_type; // required - using is_passed_directly = std::true_type; // required - - device_pointer(void *p) : base_type(static_cast(p)) {} - // needed for malloc_device, count is number of bytes to allocate - device_pointer(const std::size_t count) : base_type(count) {} - device_pointer() : base_type() {} - pointer get() const { return static_cast(this->ptr); } - operator void *() { return this->ptr; } - operator void *() const { return this->ptr; } - - // include operators from base class - using base_type::operator++; - using base_type::operator--; - device_pointer &operator++() { - ++(this->ptr); - return *this; - } - device_pointer &operator--() { - --(this->ptr); - return *this; - } - device_pointer &operator+=(difference_type forward) { - this->ptr = this->ptr + forward; - return *this; - } - device_pointer &operator-=(difference_type backward) { - this->ptr = this->ptr - backward; - return *this; - } -}; - -template -class device_pointer : public device_pointer_base> { -private: - using base_type = device_pointer_base>; - -public: - using value_type = T; - using difference_type = std::make_signed::type; - using pointer = T *; - using reference = T &; - using const_reference = const T &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::false_type; // required - using is_passed_directly = std::true_type; // required - - device_pointer(T *p) : base_type(p) {} - // needed for malloc_device, count is number of bytes to allocate - device_pointer(const std::size_t count) : base_type(count) {} - device_pointer() : base_type() {} - device_pointer &operator=(const device_iterator &in) { - this->ptr = static_cast>(in).ptr; - return *this; - } - operator device_pointer() { - return device_pointer(static_cast(this->ptr)); - } - // include operators from base class - using base_type::operator++; - using base_type::operator--; - device_pointer &operator++() { - ++(this->ptr); - return *this; - } - device_pointer &operator--() { - --(this->ptr); - return *this; - } - device_pointer &operator+=(difference_type forward) { - this->ptr = this->ptr + forward; - return *this; - } - device_pointer &operator-=(difference_type backward) { - this->ptr = this->ptr - backward; - return *this; - } -}; -#endif - -#ifdef DPCT_USM_LEVEL_NONE -template > -class device_iterator : public device_pointer { - using Base = device_pointer; - -public: - using value_type = T; - using difference_type = std::make_signed::type; - using pointer = T *; - using reference = T &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::true_type; // required - using is_passed_directly = std::false_type; // required - static constexpr sycl::access_mode mode = Mode; // required - - device_iterator() : Base() {} - device_iterator(sycl::buffer vec, std::size_t index) - : Base(vec, index) {} - device_iterator(const Base &dev_ptr) : Base(dev_ptr) {} - template - device_iterator(const device_iterator &in) - : Base(in.buffer, in.idx) {} // required for iter_mode - device_iterator &operator=(const device_iterator &in) { - Base::buffer = in.buffer; - Base::idx = in.idx; - return *this; - } - - reference operator*() const { - return const_cast(this) - ->buffer.template get_access()[Base::idx]; - } - - reference operator[](difference_type i) const { return *(*this + i); } - device_iterator &operator++() { - ++Base::idx; - return *this; - } - device_iterator &operator--() { - --Base::idx; - return *this; - } - device_iterator operator++(int) { - device_iterator it(*this); - ++(*this); - return it; - } - device_iterator operator--(int) { - device_iterator it(*this); - --(*this); - return it; - } - device_iterator operator+(difference_type forward) const { - const auto new_idx = Base::idx + forward; - return {Base::buffer, new_idx}; - } - device_iterator &operator+=(difference_type forward) { - Base::idx += forward; - return *this; - } - device_iterator operator-(difference_type backward) const { - return {Base::buffer, Base::idx - backward}; - } - device_iterator &operator-=(difference_type backward) { - Base::idx -= backward; - return *this; - } - friend device_iterator operator+(difference_type forward, - const device_iterator &it) { - return it + forward; - } - difference_type operator-(const device_iterator &it) const { - return Base::idx - it.idx; - } - template - typename std::enable_if::value, - difference_type>::type - operator-(const OtherIterator &it) const { - return Base::idx - std::distance(oneapi::dpl::begin(Base::buffer), it); - } - bool operator==(const device_iterator &it) const { return *this - it == 0; } - bool operator!=(const device_iterator &it) const { return !(*this == it); } - bool operator<(const device_iterator &it) const { return *this - it < 0; } - bool operator>(const device_iterator &it) const { return it < *this; } - bool operator<=(const device_iterator &it) const { return !(*this > it); } - bool operator>=(const device_iterator &it) const { return !(*this < it); } - - std::size_t get_idx() const { return Base::idx; } // required - - sycl::buffer get_buffer() { - return Base::buffer; - } // required -}; -#else -template class device_iterator : public device_pointer { - using Base = device_pointer; - -protected: - std::size_t idx; - -public: - using value_type = T; - using difference_type = std::make_signed::type; - using pointer = typename Base::pointer; - using reference = typename Base::reference; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = std::false_type; // required - using is_passed_directly = std::true_type; // required - static constexpr sycl::access_mode mode = - sycl::access_mode::read_write; // required - - device_iterator() : Base(nullptr), idx(0) {} - device_iterator(T *vec, std::size_t index) : Base(vec), idx(index) {} - device_iterator(const Base &dev_ptr) : Base(dev_ptr), idx(0) {} - template - device_iterator(const device_iterator &in) - : Base(in.ptr), idx(in.idx) {} // required for iter_mode - device_iterator &operator=(const device_iterator &in) { - Base::operator=(in); - idx = in.idx; - return *this; - } - - reference operator*() const { return *(Base::ptr + idx); } - - reference operator[](difference_type i) { return Base::ptr[idx + i]; } - reference operator[](difference_type i) const { return Base::ptr[idx + i]; } - device_iterator &operator++() { - ++idx; - return *this; - } - device_iterator &operator--() { - --idx; - return *this; - } - device_iterator operator++(int) { - device_iterator it(*this); - ++(*this); - return it; - } - device_iterator operator--(int) { - device_iterator it(*this); - --(*this); - return it; - } - device_iterator operator+(difference_type forward) const { - const auto new_idx = idx + forward; - return {Base::ptr, new_idx}; - } - device_iterator &operator+=(difference_type forward) { - idx += forward; - return *this; - } - device_iterator operator-(difference_type backward) const { - return {Base::ptr, idx - backward}; - } - device_iterator &operator-=(difference_type backward) { - idx -= backward; - return *this; - } - friend device_iterator operator+(difference_type forward, - const device_iterator &it) { - return it + forward; - } - difference_type operator-(const device_iterator &it) const { - return idx - it.idx; - } - - template - typename std::enable_if::value, - difference_type>::type - operator-(const OtherIterator &it) const { - return idx - it.get_idx(); - } - - bool operator==(const device_iterator &it) const { return *this - it == 0; } - bool operator!=(const device_iterator &it) const { return !(*this == it); } - bool operator<(const device_iterator &it) const { return *this - it < 0; } - bool operator>(const device_iterator &it) const { return it < *this; } - bool operator<=(const device_iterator &it) const { return !(*this > it); } - bool operator>=(const device_iterator &it) const { return !(*this < it); } - - std::size_t get_idx() const { return idx; } // required - - device_iterator &get_buffer() { return *this; } // required - - std::size_t size() const { return idx; } -}; -#endif - -struct sys_tag {}; -struct device_sys_tag : public sys_tag {}; -struct host_sys_tag : public sys_tag {}; - -#ifdef DPCT_USM_LEVEL_NONE -template class tagged_pointer { - static_assert(false, - "tagged_pointer is not supported with DPCT_USM_LEVEL_NONE"); -}; -template -void release_temporary_allocation(PolicyOrTag &&policy_or_tag, Pointer ptr) { - static_assert( - false, - "release_temporary_allocation is not supported with DPCT_USM_LEVEL_NONE"); -} -template -auto get_temporary_allocation(PolicyOrTag &&policy_or_tag, - SizeType num_elements) { - static_assert( - false, - "get_temporary_allocation is not supported with DPCT_USM_LEVEL_NONE"); -} -template -auto malloc(PolicyOrTag &&policy_or_tag, const ::std::size_t num_bytes) { - static_assert(false, "malloc is not supported with DPCT_USM_LEVEL_NONE"); -} -template -auto malloc(PolicyOrTag &&policy_or_tag, const ::std::size_t num_elements) { - static_assert(false, "malloc is not supported with DPCT_USM_LEVEL_NONE"); -} -template -void free(PolicyOrTag &&policy_or_tag, Pointer ptr) { - static_assert(false, "free is not supported with DPCT_USM_LEVEL_NONE"); -} -#else -namespace internal { - -// Utility that converts a policy to a tag or reflects a provided tag -template struct policy_or_tag_to_tag { -private: - using decayed_policy_or_tag_t = ::std::decay_t; - using policy_conversion = ::std::conditional_t< - !is_hetero_execution_policy::value, host_sys_tag, - device_sys_tag>; - static constexpr bool is_policy_v = - oneapi::dpl::execution::is_execution_policy_v; - static constexpr bool is_sys_tag_v = ::std::disjunction_v< - ::std::is_same, - ::std::is_same>; - static_assert(is_policy_v || is_sys_tag_v, - "Only oneDPL policies or system tags may be provided"); - -public: - using type = ::std::conditional_t; -}; - -template -using policy_or_tag_to_tag_t = typename policy_or_tag_to_tag::type; - -template struct is_host_policy_or_tag { -private: - using tag_t = policy_or_tag_to_tag_t; - -public: - static constexpr bool value = ::std::is_same_v; -}; - -template -inline constexpr bool is_host_policy_or_tag_v = - is_host_policy_or_tag::value; - -} // namespace internal - -// TODO: Make this class an iterator adaptor. -// tagged_pointer provides a wrapper around a raw pointer type with a tag of the -// location of the allocated memory. Standard pointer operations are supported -// with this class. -template class tagged_pointer { -public: - using value_type = T; - using difference_type = ::std::ptrdiff_t; - using pointer = T *; - using reference = T &; - using iterator_category = std::random_access_iterator_tag; - using is_hetero = ::std::false_type; - using is_passed_directly = std::true_type; - - tagged_pointer() : m_ptr(nullptr) {} - tagged_pointer(T *ptr) : m_ptr(ptr) {} - T &operator[](difference_type idx) { return this->m_ptr[idx]; } - const T &operator[](difference_type idx) const { return this->m_ptr[idx]; } - tagged_pointer operator+(difference_type forward) const { - return tagged_pointer{this->m_ptr + forward}; - } - tagged_pointer operator-(difference_type backward) const { - return tagged_pointer{this->m_ptr - backward}; - } - operator const T *() const { return m_ptr; } - operator T *() { return m_ptr; } - T &operator*() { return *this->m_ptr; } - const T &operator*() const { return *this->m_ptr; } - T *operator->() { return this->m_ptr; } - const T *operator->() const { return this->m_ptr; } - tagged_pointer operator++(int) { - tagged_pointer p(this->m_ptr); - ++this->m_ptr; - return p; - } - tagged_pointer operator--(int) { - tagged_pointer p(this->m_ptr); - --this->m_ptr; - return p; - } - tagged_pointer &operator++() { - ++this->m_ptr; - return *this; - } - tagged_pointer &operator--() { - --this->m_ptr; - return *this; - } - difference_type operator-(const tagged_pointer &it) const { - return this->m_ptr - it.m_ptr; - } - tagged_pointer &operator+=(difference_type forward) { - this->m_ptr = this->m_ptr + forward; - return *this; - } - tagged_pointer &operator-=(difference_type backward) { - this->m_ptr = this->m_ptr - backward; - return *this; - } - -private: - T *m_ptr; -}; - -// Void specialization for tagged pointers. Iterator traits are not provided but -// conversion to other non-void tagged pointers is allowed. Pointer arithmetic -// is disallowed with this specialization. -template class tagged_pointer { -public: - using difference_type = ::std::ptrdiff_t; - using pointer = void *; - tagged_pointer() : m_ptr(nullptr) {} - tagged_pointer(pointer ptr) : m_ptr(ptr) {} - operator const void *() const { return m_ptr; } - operator void *() { return m_ptr; } - // Enable tagged void pointer to convert to all other raw pointer types. - template operator OtherPtr *() const { - return static_cast(this->m_ptr); - } - -private: - void *m_ptr; -}; - -namespace internal { - -// Internal utility to return raw pointer to allocated memory. Note that host -// allocations are not device accessible (not pinned). -template -void *malloc_base(PolicyOrTag &&policy_or_tag, const ::std::size_t num_bytes) { - using decayed_policy_or_tag_t = ::std::decay_t; - if constexpr (internal::is_host_policy_or_tag_v) { - return ::std::malloc(num_bytes); - } else { - sycl::queue q; - // Grab the associated queue if a device policy is provided. Otherwise, use - // default constructed. - if constexpr (oneapi::dpl::execution::is_execution_policy_v< - decayed_policy_or_tag_t>) { - q = policy_or_tag.queue(); - } else { - q = get_default_queue(); - } - return sycl::malloc_shared(num_bytes, q); - } -} - -} // namespace internal - -template -auto malloc(PolicyOrTag &&policy_or_tag, const ::std::size_t num_bytes) { - return tagged_pointer>( - internal::malloc_base(::std::forward(policy_or_tag), - num_bytes)); -} - -template -auto malloc(PolicyOrTag &&policy_or_tag, const ::std::size_t num_elements) { - return tagged_pointer>( - static_cast( - internal::malloc_base(::std::forward(policy_or_tag), - num_elements * sizeof(T)))); -} - -template -void free(PolicyOrTag &&policy_or_tag, Pointer ptr) { - using decayed_policy_or_tag_t = ::std::decay_t; - if constexpr (internal::is_host_policy_or_tag_v) { - ::std::free(ptr); - } else { - sycl::queue q; - // Grab the associated queue if a device policy is provided. Otherwise, use - // default constructed. - if constexpr (oneapi::dpl::execution::is_execution_policy_v< - decayed_policy_or_tag_t>) { - q = policy_or_tag.queue(); - } else { - q = get_default_queue(); - } - sycl::free(ptr, q); - } -} - -template -auto get_temporary_allocation(PolicyOrTag &&policy_or_tag, - SizeType num_elements) { - auto allocation_ptr = - dpct::malloc(::std::forward(policy_or_tag), num_elements); - if (allocation_ptr == nullptr) - return ::std::make_pair(allocation_ptr, SizeType(0)); - return ::std::make_pair(allocation_ptr, num_elements); -} - -template -void release_temporary_allocation(PolicyOrTag &&policy_or_tag, Pointer ptr) { - dpct::free(::std::forward(policy_or_tag), ptr); -} -#endif - -template -device_pointer malloc_device(const std::size_t num_elements) { - return device_pointer(num_elements * sizeof(T)); -} -static inline device_pointer malloc_device(const std::size_t num_bytes) { - return device_pointer(num_bytes); -} -#ifdef DPCT_USM_LEVEL_NONE -template -device_pointer device_new(device_pointer p, const T &value, - const std::size_t count) { - auto converted_buf = p.buffer.template reinterpret(sycl::range<1>(count)); - ::std::uninitialized_fill( - oneapi::dpl::execution::make_device_policy(dpct::get_default_queue()), - oneapi::dpl::begin(converted_buf), - oneapi::dpl::end(converted_buf), value); - return device_pointer(converted_buf, p.idx); -} -// buffer manages lifetime -template void free_device(device_pointer ptr) {} -#else -template -device_pointer device_new(device_pointer p, const T &value, - const std::size_t count = 1) { - dpct::device_pointer converted_p(static_cast(p.get())); - ::std::uninitialized_fill( - oneapi::dpl::execution::make_device_policy(dpct::get_default_queue()), - converted_p, converted_p + count, value); - return converted_p; -} -template void free_device(device_pointer ptr) { - sycl::free(ptr.get(), dpct::get_default_queue()); -} -#endif -template -device_pointer device_new(device_pointer p, - const std::size_t count = 1) { - return device_new(p, T{}, count); -} -template -device_pointer device_new(const std::size_t count = 1) { - return device_new(device_pointer(sizeof(T) * count), T{}, count); -} - -template -typename std::enable_if::value, void>::type -device_delete(device_pointer p, const std::size_t count = 1) { - ::std::destroy(oneapi::dpl::execution::make_device_policy(dpct::get_default_queue()), - p, p + count); - free_device(p); -} -template -typename std::enable_if::value, void>::type -device_delete(device_pointer p, const std::size_t count = 1) { - free_device(p); -} - -template device_pointer get_device_pointer(T *ptr) { - return device_pointer(ptr); -} - -template -device_pointer get_device_pointer(const device_pointer &ptr) { - return device_pointer(ptr); -} - -template T *get_raw_pointer(const device_pointer &ptr) { - return ptr.get(); -} - -template Pointer get_raw_pointer(const Pointer &ptr) { - return ptr; -} - -template const T &get_raw_reference(const device_reference &ref) { - return ref.value; -} - -template T &get_raw_reference(device_reference &ref) { - return ref.value; -} - -template const T &get_raw_reference(const T &ref) { - return ref; -} - -template T &get_raw_reference(T &ref) { - return ref; -} - -} // namespace dpct - -#endif diff --git a/dpct/dpl_extras/numeric.h b/dpct/dpl_extras/numeric.h deleted file mode 100644 index 9864cd173..000000000 --- a/dpct/dpl_extras/numeric.h +++ /dev/null @@ -1,32 +0,0 @@ -//==---- numeric.h --------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_NUMERIC_H__ -#define __DPCT_NUMERIC_H__ - -namespace dpct { - -template -T inner_product(Policy &&policy, InputIt1 first1, InputIt1 last1, - InputIt2 first2, T init) { - return std::transform_reduce(std::forward(policy), first1, last1, - first2, init); -} - -template -T inner_product(Policy &&policy, InputIt1 first1, InputIt1 last1, - InputIt2 first2, T init, BinaryOperation1 op1, - BinaryOperation2 op2) { - return std::transform_reduce(std::forward(policy), first1, last1, - first2, init, op1, op2); -} - -} // end namespace dpct - -#endif diff --git a/dpct/dpl_extras/vector.h b/dpct/dpl_extras/vector.h deleted file mode 100644 index afba575ae..000000000 --- a/dpct/dpl_extras/vector.h +++ /dev/null @@ -1,752 +0,0 @@ -//==---- vector.h ---------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_VECTOR_H__ -#define __DPCT_VECTOR_H__ - -#include -#include - -#include - -#include "memory.h" - -#include -#include -#include - -#include "../device.hpp" - -namespace dpct { - -namespace internal { -template // for non-iterators -struct is_iterator : std::false_type {}; - -template // For iterators -struct is_iterator< - Iter, - typename std::enable_if< - !std::is_void::value, void>::type> - : std::true_type {}; - -template // For pointers -struct is_iterator : std::true_type {}; -} // end namespace internal - -#ifndef DPCT_USM_LEVEL_NONE - -template > -class device_vector { -public: - using iterator = device_iterator; - using const_iterator = const iterator; - using reference = device_reference; - using const_reference = const reference; - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using difference_type = - typename ::std::iterator_traits::difference_type; - using size_type = ::std::size_t; - -private: - Allocator _alloc; - size_type _size; - size_type _capacity; - pointer _storage; - - size_type _min_capacity() const { return size_type(1); } - - void _set_capacity_and_alloc() { - _capacity = ::std::max(_size * 2, _min_capacity()); - _storage = _alloc.allocate(_capacity); - } - -public: - template operator ::std::vector() const { - auto __tmp = ::std::vector(this->size()); - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - this->begin(), this->end(), __tmp.begin()); - return __tmp; - } - device_vector() - : _alloc(get_default_queue()), _size(0), _capacity(_min_capacity()) { - _set_capacity_and_alloc(); - } - ~device_vector() /*= default*/ { _alloc.deallocate(_storage, _capacity); }; - explicit device_vector(size_type n) : device_vector(n, T()) {} - explicit device_vector(size_type n, const T &value) - : _alloc(get_default_queue()), _size(n) { - _set_capacity_and_alloc(); - if (_size > 0) { - ::std::fill(oneapi::dpl::execution::make_device_policy(get_default_queue()), - begin(), end(), T(value)); - } - } - device_vector(const device_vector &other) : _alloc(get_default_queue()) { - _size = other.size(); - _capacity = other.capacity(); - _storage = _alloc.allocate(_capacity); - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - other.begin(), other.end(), begin()); - } - device_vector(device_vector &&other) - : _alloc(get_default_queue()), _size(other.size()), - _capacity(other.capacity()), _storage(other._storage) { - other._size = 0; - other._capacity = 0; - other._storage = nullptr; - } - - template - device_vector(InputIterator first, - typename ::std::enable_if< - internal::is_iterator::value && - !::std::is_pointer::value && - ::std::is_same::iterator_category, - ::std::random_access_iterator_tag>::value, - InputIterator>::type last) - : _alloc(get_default_queue()) { - _size = ::std::distance(first, last); - _set_capacity_and_alloc(); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, begin()); - } - } - - template - device_vector(InputIterator first, - typename ::std::enable_if<::std::is_pointer::value, - InputIterator>::type last) - : _alloc(get_default_queue()) { - _size = ::std::distance(first, last); - _set_capacity_and_alloc(); - if (_size > 0) { - auto ptr_type = sycl::get_pointer_type(first, get_default_context()); - if (ptr_type != sycl::usm::alloc::host && - ptr_type != sycl::usm::alloc::unknown) { - ::std::copy( - oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, begin()); - } else { - sycl::buffer::value_type, - 1> - buf(first, last); - auto buf_first = oneapi::dpl::begin(buf); - auto buf_last = oneapi::dpl::end(buf); - ::std::copy( - oneapi::dpl::execution::make_device_policy(get_default_queue()), - buf_first, buf_last, begin()); - } - } - } - - template - device_vector(InputIterator first, - typename ::std::enable_if< - internal::is_iterator::value && - !::std::is_pointer::value && - !::std::is_same::iterator_category, - ::std::random_access_iterator_tag>::value, - InputIterator>::type last) - : _alloc(get_default_queue()), _size(::std::distance(first, last)) { - _set_capacity_and_alloc(); - ::std::vector _tmp(first, last); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - _tmp.begin(), _tmp.end(), this->begin()); - } - } - - template - device_vector(const device_vector &v) - : _alloc(get_default_queue()), _storage(v.real_begin()), _size(v.size()), - _capacity(v.capacity()) {} - - template - device_vector(::std::vector &v) - : _alloc(get_default_queue()), _size(v.size()) { - _set_capacity_and_alloc(); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - v.begin(), v.end(), this->begin()); - } - } - - template - device_vector &operator=(const ::std::vector &v) { - resize(v.size()); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - v.begin(), v.end(), begin()); - } - return *this; - } - device_vector &operator=(const device_vector &other) { - // Copy assignment operator: - resize(other.size()); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - other.begin(), other.end(), begin()); - } - return *this; - } - device_vector &operator=(device_vector &&other) { - // Move assignment operator: - device_vector dummy(::std::move(other)); - this->swap(dummy); - return *this; - } - size_type size() const { return _size; } - iterator begin() noexcept { return device_iterator(_storage, 0); } - iterator end() { return device_iterator(_storage, size()); } - const_iterator begin() const noexcept { - return device_iterator(_storage, 0); - } - const_iterator cbegin() const noexcept { return begin(); } - const_iterator end() const { return device_iterator(_storage, size()); } - const_iterator cend() const { return end(); } - T *real_begin() { return _storage; } - const T *real_begin() const { return _storage; } - void swap(device_vector &v) { - ::std::swap(_size, v._size); - ::std::swap(_capacity, v._capacity); - ::std::swap(_storage, v._storage); - ::std::swap(_alloc, v._alloc); - } - reference operator[](size_type n) { return _storage[n]; } - const_reference operator[](size_type n) const { return _storage[n]; } - void reserve(size_type n) { - if (n > capacity()) { - // allocate buffer for new size - auto tmp = _alloc.allocate(2 * n); - // copy content (old buffer to new buffer) - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - begin(), end(), tmp); - // deallocate old memory - _alloc.deallocate(_storage, _capacity); - _storage = tmp; - _capacity = 2 * n; - } - } - void resize(size_type new_size, const T &x = T()) { - reserve(new_size); - if (_size < new_size) { - ::std::fill(oneapi::dpl::execution::make_device_policy(get_default_queue()), - begin() + _size, begin() + new_size, x); - } - _size = new_size; - } - size_type max_size(void) const { - return ::std::numeric_limits::max() / sizeof(T); - } - size_type capacity() const { return _capacity; } - const_reference front() const { return *begin(); } - reference front() { return *begin(); } - const_reference back(void) const { return *(end() - 1); } - reference back(void) { return *(end() - 1); } - pointer data(void) { return _storage; } - const_pointer data(void) const { return _storage; } - void shrink_to_fit(void) { - if (_size != capacity()) { - size_type tmp_capacity = ::std::max(_size, _min_capacity()); - auto tmp = _alloc.allocate(tmp_capacity); - if (_size > 0) { - ::std::copy( - oneapi::dpl::execution::make_device_policy(get_default_queue()), - begin(), end(), tmp); - } - _alloc.deallocate(_storage, _capacity); - _storage = tmp; - _capacity = tmp_capacity; - } - } - void assign(size_type n, const T &x) { - resize(n); - if (_size > 0) { - ::std::fill(oneapi::dpl::execution::make_device_policy(get_default_queue()), - begin(), begin() + n, x); - } - } - template - void - assign(InputIterator first, - typename ::std::enable_if::value, - InputIterator>::type last) { - auto n = ::std::distance(first, last); - resize(n); - if (_size > 0) { - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, begin()); - } - } - void clear(void) { _size = 0; } - bool empty(void) const { return (size() == 0); } - void push_back(const T &x) { insert(end(), size_type(1), x); } - void pop_back(void) { - if (_size > 0) - --_size; - } - iterator erase(iterator first, iterator last) { - auto n = ::std::distance(first, last); - if (last == end()) { - _size = _size - n; - return end(); - } - auto m = ::std::distance(last, end()); - if (m <= 0) { - return end(); - } - auto tmp = _alloc.allocate(m); - // copy remainder to temporary buffer. - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - last, end(), tmp); - // override (erase) subsequence in storage. - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - tmp, tmp + m, first); - _alloc.deallocate(tmp, m); - _size -= n; - return begin() + first.get_idx() + n; - } - iterator erase(iterator pos) { return erase(pos, pos + 1); } - iterator insert(iterator position, const T &x) { - auto n = ::std::distance(begin(), position); - insert(position, size_type(1), x); - return begin() + n; - } - void insert(iterator position, size_type n, const T &x) { - if (position == end()) { - resize(size() + n); - ::std::fill(oneapi::dpl::execution::make_device_policy(get_default_queue()), - end() - n, end(), x); - } else { - auto i_n = ::std::distance(begin(), position); - // allocate temporary storage - auto m = ::std::distance(position, end()); - // will throw if position is not inside active vector - auto tmp = _alloc.allocate(m); - // copy remainder - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - position, end(), tmp); - - resize(size() + n); - // resizing might invalidate position - position = begin() + position.get_idx(); - - ::std::fill(oneapi::dpl::execution::make_device_policy(get_default_queue()), - position, position + n, x); - - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - tmp, tmp + m, position + n); - _alloc.deallocate(tmp, m); - } - } - template - void - insert(iterator position, InputIterator first, - typename ::std::enable_if::value, - InputIterator>::type last) { - auto n = ::std::distance(first, last); - if (position == end()) { - resize(size() + n); - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, end()); - } else { - auto m = ::std::distance(position, end()); - // will throw if position is not inside active vector - auto tmp = _alloc.allocate(m); - - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - position, end(), tmp); - - resize(size() + n); - // resizing might invalidate position - position = begin() + position.get_idx(); - - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, position); - ::std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - tmp, tmp + m, position + n); - _alloc.deallocate(tmp, m); - } - } - Allocator get_allocator() const { return _alloc; } -}; - -#else - -template > -class device_vector { - static_assert( - std::is_same>::value, - "device_vector doesn't support custom allocator when USM is not used."); - -public: - using iterator = device_iterator; - using const_iterator = const iterator; - using reference = device_reference; - using const_reference = const reference; - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using difference_type = - typename std::iterator_traits::difference_type; - using size_type = std::size_t; - -private: - using Buffer = sycl::buffer; - using Range = sycl::range<1>; - // Using mem_mgr to handle memory allocation - void *_storage; - size_type _size; - - size_type _min_capacity() const { return size_type(1); } - - void *alloc_store(size_type num_bytes) { - return detail::mem_mgr::instance().mem_alloc(num_bytes); - } - -public: - template operator std::vector() const { - auto __tmp = std::vector(this->size()); - std::copy(oneapi::dpl::execution::dpcpp_default, this->begin(), this->end(), - __tmp.begin()); - return __tmp; - } - device_vector() - : _storage(alloc_store(_min_capacity() * sizeof(T))), _size(0) {} - ~device_vector() = default; - explicit device_vector(size_type n) : device_vector(n, T()) {} - explicit device_vector(size_type n, const T &value) - : _storage(alloc_store(std::max(n, _min_capacity()) * sizeof(T))), - _size(n) { - auto buf = get_buffer(); - std::fill(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(buf), - oneapi::dpl::begin(buf) + n, T(value)); - } - device_vector(const device_vector &other) - : _storage(other._storage), _size(other.size()) {} - device_vector(device_vector &&other) - : _storage(std::move(other._storage)), _size(other.size()) {} - - template - device_vector(InputIterator first, - typename std::enable_if< - internal::is_iterator::value && - !std::is_pointer::value && - std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - InputIterator>::type last) - : _storage(alloc_store(std::distance(first, last) * sizeof(T))), - _size(std::distance(first, last)) { - auto buf = get_buffer(); - auto dst = oneapi::dpl::begin(buf); - std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - first, last, dst); - } - - template - device_vector(InputIterator first, - typename std::enable_if::value, - InputIterator>::type last) - : _storage(alloc_store(std::distance(first, last) * sizeof(T))), - _size(std::distance(first, last)) { - auto buf = get_buffer(); - Buffer tmp_buf(first, last); - auto start = oneapi::dpl::begin(tmp_buf); - auto end = oneapi::dpl::end(tmp_buf); - auto dst = oneapi::dpl::begin(buf); - std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - start, end, dst); - } - - template - device_vector(InputIterator first, - typename std::enable_if< - internal::is_iterator::value && - !std::is_same::iterator_category, - std::random_access_iterator_tag>::value, - InputIterator>::type last) - : _storage(alloc_store(std::distance(first, last) * sizeof(T))), - _size(std::distance(first, last)) { - auto buf = get_buffer(); - std::vector tmp(first, last); - Buffer tmp_buf(tmp); - auto start = oneapi::dpl::begin(tmp_buf); - auto end = oneapi::dpl::end(tmp_buf); - auto dst = oneapi::dpl::begin(buf); - std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - start, end, dst); - } - - template - device_vector(const device_vector &v) - : _storage(alloc_store(v.size() * sizeof(T))), _size(v.size()) { - auto buf = get_buffer(); - auto dst = oneapi::dpl::begin(buf); - std::copy(oneapi::dpl::execution::make_device_policy(get_default_queue()), - v.real_begin(), v.real_begin() + v.size(), dst); - } - - template - device_vector(std::vector &v) - : _storage(alloc_store(v.size() * sizeof(T))), _size(v.size()) { - std::copy(oneapi::dpl::execution::dpcpp_default, v.begin(), v.end(), - oneapi::dpl::begin(get_buffer())); - } - - device_vector &operator=(const device_vector &other) { - // Copy assignment operator: - _size = other.size(); - void *tmp = alloc_store(_size * sizeof(T)); - auto tmp_buf = - detail::mem_mgr::instance() - .translate_ptr(tmp) - .buffer.template reinterpret(sycl::range<1>(_size)); - std::copy(oneapi::dpl::execution::dpcpp_default, - oneapi::dpl::begin(other.get_buffer()), - oneapi::dpl::end(other.get_buffer()), - oneapi::dpl::begin(tmp_buf)); - detail::mem_mgr::instance().mem_free(_storage); - _storage = tmp; - return *this; - } - device_vector &operator=(device_vector &&other) { - // Move assignment operator: - _size = other.size(); - this->_storage = std::move(other._storage); - return *this; - } - template - device_vector &operator=(const std::vector &v) { - Buffer data(v.begin(), v.end()); - _size = v.size(); - void *tmp = alloc_store(_size * sizeof(T)); - auto tmp_buf = - detail::mem_mgr::instance() - .translate_ptr(tmp) - .buffer.template reinterpret(sycl::range<1>(_size)); - std::copy(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(data), - oneapi::dpl::end(data), oneapi::dpl::begin(tmp_buf)); - detail::mem_mgr::instance().mem_free(_storage); - _storage = tmp; - - return *this; - } - Buffer get_buffer() const { - return detail::mem_mgr::instance() - .translate_ptr(_storage) - .buffer.template reinterpret(sycl::range<1>(capacity())); - } - size_type size() const { return _size; } - iterator begin() noexcept { return device_iterator(get_buffer(), 0); } - iterator end() { return device_iterator(get_buffer(), _size); } - const_iterator begin() const noexcept { - return device_iterator(get_buffer(), 0); - } - const_iterator cbegin() const noexcept { return begin(); } - const_iterator end() const { return device_iterator(get_buffer(), _size); } - const_iterator cend() const { return end(); } - T *real_begin() { - return (detail::mem_mgr::instance() - .translate_ptr(_storage) - .buffer.template get_access()) - .get_pointer(); - } - const T *real_begin() const { - return const_cast(this) - ->detail::mem_mgr::instance() - .translate_ptr(_storage) - .buffer.template get_access() - .get_pointer(); - } - void swap(device_vector &v) { - void *temp = v._storage; - v._storage = this->_storage; - this->_storage = temp; - std::swap(_size, v._size); - } - reference operator[](size_type n) { return *(begin() + n); } - const_reference operator[](size_type n) const { return *(begin() + n); } - void reserve(size_type n) { - if (n > capacity()) { - // create new buffer (allocate for new size) - void *a = alloc_store(n * sizeof(T)); - - // copy content (old buffer to new buffer) - if (_storage != nullptr) { - auto tmp = detail::mem_mgr::instance() - .translate_ptr(a) - .buffer.template reinterpret(sycl::range<1>(n)); - auto src_buf = get_buffer(); - std::copy(oneapi::dpl::execution::dpcpp_default, - oneapi::dpl::begin(src_buf), oneapi::dpl::end(src_buf), - oneapi::dpl::begin(tmp)); - - // deallocate old memory - detail::mem_mgr::instance().mem_free(_storage); - } - _storage = a; - } - } - void resize(size_type new_size, const T &x = T()) { - reserve(new_size); - if (_size < new_size) { - auto src_buf = get_buffer(); - std::fill(oneapi::dpl::execution::dpcpp_default, - oneapi::dpl::begin(src_buf) + _size, - oneapi::dpl::begin(src_buf) + new_size, x); - } - _size = new_size; - } - size_type max_size(void) const { - return std::numeric_limits::max() / sizeof(T); - } - size_type capacity() const { - return _storage != nullptr ? detail::mem_mgr::instance() - .translate_ptr(_storage) - .buffer.size() / - sizeof(T) - : 0; - } - const_reference front() const { return *begin(); } - reference front() { return *begin(); } - const_reference back(void) const { return *(end() - 1); } - reference back(void) { return *(end() - 1); } - pointer data(void) { return reinterpret_cast(_storage); } - const_pointer data(void) const { - return reinterpret_cast(_storage); - } - void shrink_to_fit(void) { - if (_size != capacity()) { - void *a = alloc_store(_size * sizeof(T)); - auto tmp = detail::mem_mgr::instance() - .translate_ptr(a) - .buffer.template reinterpret(sycl::range<1>(_size)); - std::copy(oneapi::dpl::execution::dpcpp_default, - oneapi::dpl::begin(get_buffer()), - oneapi::dpl::begin(get_buffer()) + _size, - oneapi::dpl::begin(tmp)); - detail::mem_mgr::instance().mem_free(_storage); - _storage = a; - } - } - void assign(size_type n, const T &x) { - resize(n); - std::fill(oneapi::dpl::execution::dpcpp_default, begin(), begin() + n, x); - } - template - void - assign(InputIterator first, - typename std::enable_if::value, - InputIterator>::type last) { - auto n = std::distance(first, last); - resize(n); - if (internal::is_iterator::value && - !std::is_pointer::value) - std::copy(oneapi::dpl::execution::dpcpp_default, first, last, begin()); - else { - Buffer tmp(first, last); - std::copy(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(tmp), - oneapi::dpl::end(tmp), begin()); - } - } - void clear(void) { - _size = 0; - detail::mem_mgr::instance().mem_free(_storage); - _storage = nullptr; - } - bool empty(void) const { return (size() == 0); } - void push_back(const T &x) { insert(end(), size_type(1), x); } - void pop_back(void) { - if (_size > 0) - --_size; - } - iterator erase(iterator first, iterator last) { - auto n = std::distance(first, last); - if (last == end()) { - _size = _size - n; - return end(); - } - Buffer tmp{Range(std::distance(last, end()))}; - // copy remainder to temporary buffer. - std::copy(oneapi::dpl::execution::dpcpp_default, last, end(), - oneapi::dpl::begin(tmp)); - // override (erase) subsequence in storage. - std::copy(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(tmp), - oneapi::dpl::end(tmp), first); - resize(_size - n); - return begin() + first.get_idx() + n; - } - iterator erase(iterator pos) { return erase(pos, pos + 1); } - iterator insert(iterator position, const T &x) { - auto n = std::distance(begin(), position); - insert(position, size_type(1), x); - return begin() + n; - } - void insert(iterator position, size_type n, const T &x) { - if (position == end()) { - resize(size() + n); - std::fill(oneapi::dpl::execution::dpcpp_default, end() - n, end(), x); - } else { - auto i_n = std::distance(begin(), position); - // allocate temporary storage - Buffer tmp{Range(std::distance(position, end()))}; - // copy remainder - std::copy(oneapi::dpl::execution::dpcpp_default, position, end(), - oneapi::dpl::begin(tmp)); - - resize(size() + n); - // resizing might invalidate position - position = begin() + position.get_idx(); - - std::fill(oneapi::dpl::execution::dpcpp_default, position, position + n, - x); - - std::copy(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(tmp), - oneapi::dpl::end(tmp), position + n); - } - } - template - void - insert(iterator position, InputIterator first, - typename std::enable_if::value, - InputIterator>::type last) { - auto n = std::distance(first, last); - if (position == end()) { - resize(size() + n); - std::copy(oneapi::dpl::execution::dpcpp_default, first, last, end()); - } else { - Buffer tmp{Range(std::distance(position, end()))}; - - std::copy(oneapi::dpl::execution::dpcpp_default, position, end(), - oneapi::dpl::begin(tmp)); - - resize(size() + n); - // resizing might invalidate position - position = begin() + position.get_idx(); - - std::copy(oneapi::dpl::execution::dpcpp_default, first, last, position); - std::copy(oneapi::dpl::execution::dpcpp_default, oneapi::dpl::begin(tmp), - oneapi::dpl::end(tmp), position + n); - } - } -}; - -#endif - -} // end namespace dpct - -#endif diff --git a/dpct/dpl_utils.hpp b/dpct/dpl_utils.hpp deleted file mode 100644 index 79a6e7404..000000000 --- a/dpct/dpl_utils.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//==---- dpl_utils.hpp ----------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_DPL_UTILS_HPP__ -#define __DPCT_DPL_UTILS_HPP__ - -#define ONEDPL_USE_DPCPP_BACKEND 1 -#define __USE_DPCT 1 - -#include -#include -#include - -#include "dpl_extras/memory.h" -#include "dpl_extras/algorithm.h" -#include "dpl_extras/numeric.h" -#include "dpl_extras/iterators.h" -#include "dpl_extras/vector.h" -#include "dpl_extras/dpcpp_extensions.h" - -#endif // __DPCT_DPL_UTILS_HPP__ diff --git a/dpct/fft_utils.hpp b/dpct/fft_utils.hpp deleted file mode 100644 index cba1b253c..000000000 --- a/dpct/fft_utils.hpp +++ /dev/null @@ -1,1376 +0,0 @@ -//==---- fft_utils.hpp ----------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_FFT_UTILS_HPP__ -#define __DPCT_FFT_UTILS_HPP__ - -#include -#include -#include -#include -#include "lib_common_utils.hpp" - -namespace dpct { -namespace fft { -/// An enumeration type to describe the FFT direction is forward or backward. -enum fft_direction : int { - forward = 0, - backward -}; -/// An enumeration type to describe the types of FFT input and output data. -enum fft_type : int { - real_float_to_complex_float = 0, - complex_float_to_real_float, - real_double_to_complex_double, - complex_double_to_real_double, - complex_float_to_complex_float, - complex_double_to_complex_double, -}; - -/// A class to perform FFT calculation. -class fft_engine { -public: - /// Default constructor. - fft_engine() {} - /// Commit the configuration to calculate n-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] input_type Input data type. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] output_type Output data type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int dim, long long *n, - long long *inembed, long long istride, long long idist, - library_data_t input_type, long long *onembed, long long ostride, - long long odist, library_data_t output_type, long long batch, - size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - _q = exec_queue; - init(dim, n, inembed, istride, idist, input_type, onembed, - ostride, odist, output_type, batch, - direction_and_placement); - if (scratchpad_size) { - if (_is_estimate_call) - *scratchpad_size = _workspace_estimate_bytes; - else - *scratchpad_size = _workspace_bytes; - } - } - /// Commit the configuration to calculate n-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] input_type Input data type. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] output_type Output data type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int dim, int *n, int *inembed, - int istride, int idist, library_data_t input_type, int *onembed, - int ostride, int odist, library_data_t output_type, int batch, - size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - _q = exec_queue; - init(dim, n, inembed, istride, idist, input_type, onembed, ostride, - odist, output_type, batch, direction_and_placement); - if (scratchpad_size) { - if (_is_estimate_call) - *scratchpad_size = _workspace_estimate_bytes; - else - *scratchpad_size = _workspace_bytes; - } - } - /// Commit the configuration to calculate n-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int dim, long long *n, - long long *inembed, long long istride, long long idist, - long long *onembed, long long ostride, long long odist, - fft_type type, long long batch, size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - commit(exec_queue, dim, n, inembed, istride, idist, - fft_type_to_data_type(type).first, onembed, ostride, odist, - fft_type_to_data_type(type).second, batch, scratchpad_size, - direction_and_placement); - } - /// Commit the configuration to calculate n-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int dim, int *n, int *inembed, - int istride, int idist, int *onembed, int ostride, int odist, - fft_type type, int batch, size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - commit(exec_queue, dim, n, inembed, istride, idist, - fft_type_to_data_type(type).first, onembed, ostride, odist, - fft_type_to_data_type(type).second, batch, scratchpad_size, - direction_and_placement); - } - /// Commit the configuration to calculate 1-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n1 The size of the dimension of the data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int n1, fft_type type, int batch, - size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - _q = exec_queue; - _n.resize(1); - _n[0] = n1; - std::tie(_input_type, _output_type) = fft_type_to_data_type(type); - _dim = 1; - _batch = batch; - _is_basic = true; - if (direction_and_placement.has_value()) { - _is_user_specified_dir_and_placement = true; - _direction = direction_and_placement->first; - _is_inplace = direction_and_placement->second; - } - config_and_commit_basic(); - if (scratchpad_size) { - if (_is_estimate_call) - *scratchpad_size = _workspace_estimate_bytes; - else - *scratchpad_size = _workspace_bytes; - } - } - /// Commit the configuration to calculate 2-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n2 The size of the 2nd dimension (outermost) of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int n2, int n1, fft_type type, - size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - _q = exec_queue; - _n.resize(2); - _n[0] = n2; - _n[1] = n1; - std::tie(_input_type, _output_type) = fft_type_to_data_type(type); - _dim = 2; - _is_basic = true; - if (direction_and_placement.has_value()) { - _is_user_specified_dir_and_placement = true; - _direction = direction_and_placement->first; - _is_inplace = direction_and_placement->second; - } - config_and_commit_basic(); - if (scratchpad_size) { - if (_is_estimate_call) - *scratchpad_size = _workspace_estimate_bytes; - else - *scratchpad_size = _workspace_bytes; - } - } - /// Commit the configuration to calculate 3-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n3 The size of the 3rd dimension (outermost) of the data. - /// \param [in] n2 The size of the 2nd dimension of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [out] scratchpad_size The workspace size required for this FFT. - /// If this value is used to allocate memory, \p direction_and_placement need - /// to be specified explicitly to get correct result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - void commit(sycl::queue *exec_queue, int n3, int n2, int n1, fft_type type, - size_t *scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - _q = exec_queue; - _n.resize(3); - _n[0] = n3; - _n[1] = n2; - _n[2] = n1; - std::tie(_input_type, _output_type) = fft_type_to_data_type(type); - _dim = 3; - _is_basic = true; - if (direction_and_placement.has_value()) { - _is_user_specified_dir_and_placement = true; - _direction = direction_and_placement->first; - _is_inplace = direction_and_placement->second; - } - config_and_commit_basic(); - if (scratchpad_size) { - if (_is_estimate_call) - *scratchpad_size = _workspace_estimate_bytes; - else - *scratchpad_size = _workspace_bytes; - } - } - - /// Create the class for calculate 1-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n1 The size of the dimension of the data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - static fft_engine * - create(sycl::queue *exec_queue, int n1, fft_type type, int batch, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = new fft_engine(); - engine->commit(exec_queue, n1, type, batch, nullptr, - direction_and_placement); - return engine; - } - /// Create the class for calculate 2-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n2 The size of the 2nd dimension (outermost) of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - static fft_engine * - create(sycl::queue *exec_queue, int n2, int n1, fft_type type, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = new fft_engine(); - engine->commit(exec_queue, n2, n1, type, nullptr, direction_and_placement); - return engine; - } - /// Create the class for calculate 3-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] n3 The size of the 3rd dimension (outermost) of the data. - /// \param [in] n2 The size of the 2nd dimension of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - static fft_engine * - create(sycl::queue *exec_queue, int n3, int n2, int n1, fft_type type, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = new fft_engine(); - engine->commit(exec_queue, n3, n2, n1, type, nullptr, - direction_and_placement); - return engine; - } - /// Create the class for calculate n-D FFT. - /// \param [in] exec_queue The queue where the calculation should be executed. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If this value is specified, the direction parameter - /// will be ignored in the fft_engine::compute function. If it is not set, - /// forward direction(if current FFT is complex-to-complex) and out-of-place - /// (false) are set by default. - static fft_engine * - create(sycl::queue *exec_queue, int dim, int *n, int *inembed, int istride, - int idist, int *onembed, int ostride, int odist, fft_type type, - int batch, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = new fft_engine(); - engine->commit(exec_queue, dim, n, inembed, istride, idist, onembed, - ostride, odist, type, batch, nullptr, - direction_and_placement); - return engine; - } - /// Create the class for calculate FFT without commit any config. - static fft_engine *create() { - fft_engine *engine = new fft_engine(); - return engine; - } - /// Destroy the class for calculate FFT. - /// \param [in] engine Pointer returned from fft_engine::craete. - static void destroy(fft_engine *engine) { delete engine; } - -#ifdef __INTEL_MKL__ - /// Estimates the workspace size for calculating n-D FFT. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] estimated_scratchpad_size The estimated workspace size - /// required for this FFT. If this value is used to allocate memory, - /// \p direction_and_placement need to be specified explicitly to get correct - /// result. - /// \param [in] direction_and_placement Explicitly specify the FFT - /// direction and placement info. If it is not set, forward direction(if - /// current FFT is complex-to-complex) and out-of-place (false) are set by default. - static void - estimate_size(int dim, long long *n, long long *inembed, long long istride, - long long idist, long long *onembed, long long ostride, - long long odist, fft_type type, long long batch, - size_t *estimated_scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = fft_engine::create(); - engine->_is_estimate_call = true; - engine->commit(&dpct::get_default_queue(), dim, n, inembed, istride, idist, - fft_type_to_data_type(type).first, onembed, ostride, odist, - fft_type_to_data_type(type).second, batch, - estimated_scratchpad_size, direction_and_placement); - fft_engine::destroy(engine); - } - /// Estimates the workspace size for calculating n-D FFT. - /// \param [in] dim Dimension number of the data. - /// \param [in] n Pointer to an array containing each dimension's size. - /// \param [in] inembed Pointer to an array containing each dimension's size - /// of the embedded input data. - /// \param [in] istride Stride size of the input data. - /// \param [in] idist Distance between the two batches of the input data. - /// \param [in] onembed Pointer to an array containing each dimension's size - /// of the embedded output data. - /// \param [in] ostride Stride size of the output data. - /// \param [in] odist Distance between the two batches of the output data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] estimated_scratchpad_size The estimated workspace size - /// required for this FFT. If this value is used to allocate memory, - /// \p direction_and_placement need to be specified explicitly to get correct - /// result. - /// \param [in] direction_and_placement Explicitly specify the FFT - /// direction and placement info. If it is not set, forward direction(if - /// current FFT is complex-to-complex) and out-of-place (false) are set by default. - static void - estimate_size(int dim, int *n, int *inembed, int istride, int idist, - int *onembed, int ostride, int odist, fft_type type, int batch, - size_t *estimated_scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = fft_engine::create(); - engine->_is_estimate_call = true; - engine->commit(&dpct::get_default_queue(), dim, n, inembed, istride, idist, - fft_type_to_data_type(type).first, onembed, ostride, odist, - fft_type_to_data_type(type).second, batch, - estimated_scratchpad_size, direction_and_placement); - fft_engine::destroy(engine); - } - /// Estimates the workspace size for calculating 1-D FFT. - /// \param [in] n1 The size of the dimension of the data. - /// \param [in] type The FFT type. - /// \param [in] batch The number of FFT operations to perform. - /// \param [out] estimated_scratchpad_size The estimated workspace size - /// required for this FFT. If this value is used to allocate memory, - /// \p direction_and_placement need to be specified explicitly to get correct - /// result. - /// \param [in] direction_and_placement Explicitly specify the FFT direction - /// and placement info. If it is not set, forward direction(if current FFT is - /// complex-to-complex) and out-of-place (false) are set by default. - static void - estimate_size(int n1, fft_type type, int batch, - size_t *estimated_scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = fft_engine::create(); - engine->_is_estimate_call = true; - engine->commit(&dpct::get_default_queue(), n1, type, batch, - estimated_scratchpad_size, direction_and_placement); - fft_engine::destroy(engine); - } - /// Estimates the workspace size for calculating 2-D FFT. - /// \param [in] n2 The size of the 2nd dimension (outermost) of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [out] estimated_scratchpad_size The estimated workspace size - /// required for this FFT. If this value is used to allocate memory, - /// \p direction_and_placement need to be specified explicitly to get correct - /// result. - /// \param [in] direction_and_placement Explicitly specify the FFT - /// direction and placement info. If it is not set, forward direction(if - /// current FFT is complex-to-complex) and out-of-place (false) are set by default. - static void - estimate_size(int n2, int n1, fft_type type, - size_t *estimated_scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = fft_engine::create(); - engine->_is_estimate_call = true; - engine->commit(&dpct::get_default_queue(), n2, n1, type, - estimated_scratchpad_size, direction_and_placement); - fft_engine::destroy(engine); - } - /// Estimates the workspace size for calculating 3-D FFT. - /// \param [in] n3 The size of the 3rd dimension (outermost) of the data. - /// \param [in] n2 The size of the 2nd dimension of the data. - /// \param [in] n1 The size of the 1st dimension (innermost) of the data. - /// \param [in] type The FFT type. - /// \param [out] estimated_scratchpad_size The estimated workspace size - /// required for this FFT. If this value is used to allocate memory, - /// \p direction_and_placement need to be specified explicitly to get correct - /// result. - /// \param [in] direction_and_placement Explicitly specify the FFT - /// direction and placement info. If it is not set, forward direction(if - /// current FFT is complex-to-complex) and out-of-place (false) are set by default. - static void - estimate_size(int n3, int n2, int n1, fft_type type, - size_t *estimated_scratchpad_size, - std::optional> - direction_and_placement = std::nullopt) { - fft_engine *engine = fft_engine::create(); - engine->_is_estimate_call = true; - engine->commit(&dpct::get_default_queue(), n3, n2, n1, type, - estimated_scratchpad_size, direction_and_placement); - fft_engine::destroy(engine); - } -#endif - - /// Execute the FFT calculation. - /// \param [in] input Pointer to the input data. - /// \param [out] output Pointer to the output data. - /// \param [in] direction The FFT direction. - template - void compute(input_t *input, output_t *output, fft_direction direction) { - if (_input_type == library_data_t::complex_float && - _output_type == library_data_t::complex_float) { - compute_complex( - (float *)input, (float *)output, direction); - } else if (_input_type == library_data_t::complex_double && - _output_type == library_data_t::complex_double) { - compute_complex( - (double *)input, (double *)output, direction); - } else if (_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) { - _direction = direction; - compute_real((float *)input, - (float *)output); - } else if (_input_type == library_data_t::complex_float && - _output_type == library_data_t::real_float) { - _direction = direction; - compute_real((float *)input, - (float *)output); - } else if (_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double) { - _direction = direction; - compute_real( - (double *)input, (double *)output); - } else if (_input_type == library_data_t::complex_double && - _output_type == library_data_t::real_double) { - _direction = direction; - compute_real( - (double *)input, (double *)output); - } - } - template <> - void compute(float *input, sycl::float2 *output, fft_direction direction) { - _direction = direction; - compute_real((float *)input, - (float *)output); - } - template <> - void compute(sycl::float2 *input, float *output, fft_direction direction) { - _direction = direction; - compute_real((float *)input, - (float *)output); - } - template <> - void compute(double *input, sycl::double2 *output, fft_direction direction) { - _direction = direction; - compute_real((double *)input, - (double *)output); - } - template <> - void compute(sycl::double2 *input, double *output, fft_direction direction) { - _direction = direction; - compute_real((double *)input, - (double *)output); - } - template <> - void compute(sycl::float2 *input, sycl::float2 *output, - fft_direction direction) { - compute_complex( - (float *)input, (float *)output, direction); - } - template <> - void compute(sycl::double2 *input, sycl::double2 *output, - fft_direction direction) { - compute_complex( - (double *)input, (double *)output, direction); - } - /// Setting the user's SYCL queue for calculation. - /// \param [in] q Pointer to the SYCL queue. - void set_queue(sycl::queue *q) { _q = q; } -#ifdef __INTEL_MKL__ - /// Setting whether to use external or internal workspace. - /// \param [in] flag True means using internal workspace. False means using - /// external workspace. - void use_internal_workspace(bool flag = true) { - _use_external_workspace = !flag; - } - /// Specify the external workspace. - /// \param [in] ptr Pointer to the workspace. - void set_workspace(void *ptr) { - if (!_use_external_workspace) { - return; - } - if (_input_type == library_data_t::complex_float && - _output_type == library_data_t::complex_float) { - if (_q->get_device().is_gpu()) { - auto data = dpct::detail::get_memory(ptr); - _desc_sc->set_workspace(data); - } - } else if (_input_type == library_data_t::complex_double && - _output_type == library_data_t::complex_double) { - if (_q->get_device().is_gpu()) { - auto data = dpct::detail::get_memory(ptr); - _desc_dc->set_workspace(data); - } - } else if ((_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) || - (_input_type == library_data_t::complex_float && - _output_type == library_data_t::real_float)) { - if (_q->get_device().is_gpu()) { - auto data = dpct::detail::get_memory(ptr); - _desc_sr->set_workspace(data); - } - } else if ((_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double) || - (_input_type == library_data_t::complex_double && - _output_type == library_data_t::real_double)) { - if (_q->get_device().is_gpu()) { - auto data = dpct::detail::get_memory(ptr); - _desc_dr->set_workspace(data); - } - } else { - throw sycl::exception(sycl::make_error_code(sycl::errc::invalid), - "invalid fft type"); - } - } -#endif - /// Get the workspace size. - /// \param [out] scratchpad_size Workspace size in bytes. - void get_workspace_size(size_t *scratchpad_size) { - if (scratchpad_size) { - *scratchpad_size = _workspace_bytes; - } - } - -private: - static std::pair - fft_type_to_data_type(fft_type type) { - switch (type) { - case fft_type::real_float_to_complex_float: { - return std::make_pair(library_data_t::real_float, - library_data_t::complex_float); - } - case fft_type::complex_float_to_real_float: { - return std::make_pair(library_data_t::complex_float, - library_data_t::real_float); - } - case fft_type::real_double_to_complex_double: { - return std::make_pair(library_data_t::real_double, - library_data_t::complex_double); - } - case fft_type::complex_double_to_real_double: { - return std::make_pair(library_data_t::complex_double, - library_data_t::real_double); - } - case fft_type::complex_float_to_complex_float: { - return std::make_pair(library_data_t::complex_float, - library_data_t::complex_float); - } - case fft_type::complex_double_to_complex_double: { - return std::make_pair(library_data_t::complex_double, - library_data_t::complex_double); - } - } - } - - void config_and_commit_basic() { - if (_input_type == library_data_t::complex_float && - _output_type == library_data_t::complex_float) { - _desc_sc = std::make_shared< - oneapi::mkl::dft::descriptor>(_n); - std::int64_t distance = 1; - for (auto i : _n) - distance = distance * i; - _fwd_dist = distance; - _bwd_dist = distance; - _desc_sc->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - distance); - _desc_sc->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - distance); - _desc_sc->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - _batch); -#ifdef __INTEL_MKL__ - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - else - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - if (_use_external_workspace) { - if (_q->get_device().is_gpu()) { - _desc_sc->set_value( - oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - } - } - if (_is_estimate_call) { - if (_q->get_device().is_gpu()) { - _desc_sc->get_value( - oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES, - &_workspace_estimate_bytes); - } - } else { - _desc_sc->commit(*_q); - if (_q->get_device().is_gpu()) { - _desc_sc->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES, - &_workspace_bytes); - } - } -#else - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - else - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - _desc_sc->commit(*_q); -#endif - } else if (_input_type == library_data_t::complex_double && - _output_type == library_data_t::complex_double) { - _desc_dc = std::make_shared< - oneapi::mkl::dft::descriptor>(_n); - std::int64_t distance = 1; - for (auto i : _n) - distance = distance * i; - _fwd_dist = distance; - _bwd_dist = distance; - _desc_dc->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - distance); - _desc_dc->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - distance); - _desc_dc->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - _batch); -#ifdef __INTEL_MKL__ - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - else - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - if (_use_external_workspace) { - if (_q->get_device().is_gpu()) { - _desc_dc->set_value( - oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - } - } - if (_is_estimate_call) { - if (_q->get_device().is_gpu()) { - _desc_dc->get_value( - oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES, - &_workspace_estimate_bytes); - } - } else { - _desc_dc->commit(*_q); - if (_q->get_device().is_gpu()) { - _desc_dc->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES, - &_workspace_bytes); - } - } -#else - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - else - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - _desc_dc->commit(*_q); -#endif - } else if ((_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) || - (_input_type == library_data_t::complex_float && - _output_type == library_data_t::real_float)) { - _desc_sr = std::make_shared>( - _n); - if (_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) - _direction = fft_direction::forward; - else - _direction = fft_direction::backward; - _desc_sr->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - _batch); -#ifdef __INTEL_MKL__ - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - set_stride_and_distance_basic(_desc_sr); - } else { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - set_stride_and_distance_basic(_desc_sr); - } - if (_use_external_workspace) { - if (_q->get_device().is_gpu()) { - _desc_sr->set_value( - oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - } - } - if (_is_estimate_call) { - if (_q->get_device().is_gpu()) { - _desc_sr->get_value( - oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES, - &_workspace_estimate_bytes); - } - } else { - _desc_sr->commit(*_q); - if (_q->get_device().is_gpu()) { - _desc_sr->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES, - &_workspace_bytes); - } - } -#else - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - set_stride_and_distance_basic(_desc_sr); - } else { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - set_stride_and_distance_basic(_desc_sr); - } - _desc_sr->commit(*_q); -#endif - } else if ((_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double) || - (_input_type == library_data_t::complex_double && - _output_type == library_data_t::real_double)) { - _desc_dr = std::make_shared>( - _n); - if (_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double) - _direction = fft_direction::forward; - else - _direction = fft_direction::backward; - _desc_dr->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - _batch); -#ifdef __INTEL_MKL__ - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - set_stride_and_distance_basic(_desc_dr); - } else { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - set_stride_and_distance_basic(_desc_dr); - } - if (_use_external_workspace) { - if (_q->get_device().is_gpu()) { - _desc_dr->set_value( - oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - } - } - if (_is_estimate_call) { - if (_q->get_device().is_gpu()) { - _desc_dr->get_value( - oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES, - &_workspace_estimate_bytes); - } - } else { - _desc_dr->commit(*_q); - if (_q->get_device().is_gpu()) { - _desc_dr->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES, - &_workspace_bytes); - } - } -#else - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - set_stride_and_distance_basic(_desc_dr); - } else { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - set_stride_and_distance_basic(_desc_dr); - } - _desc_dr->commit(*_q); -#endif - } else { - throw sycl::exception(sycl::make_error_code(sycl::errc::invalid), - "invalid fft type"); - } - } - - void config_and_commit_advanced() { -#ifdef __INTEL_MKL__ -#define CONFIG_AND_COMMIT(DESC, PREC, DOM, TYPE) \ - { \ - DESC = std::make_shared>( \ - _n); \ - set_stride_advanced(DESC); \ - DESC->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _fwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, \ - _batch); \ - if (_is_user_specified_dir_and_placement && _is_inplace) \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - DFTI_CONFIG_VALUE::DFTI_INPLACE); \ - else \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); \ - if (_use_external_workspace) { \ - DESC->set_value(oneapi::mkl::dft::config_param::WORKSPACE, \ - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); \ - } \ - if (_is_estimate_call) { \ - if (_q->get_device().is_gpu()) { \ - DESC->get_value( \ - oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES, \ - &_workspace_estimate_bytes); \ - } \ - } else { \ - DESC->commit(*_q); \ - if (_is_estimate_call) { \ - DESC->get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES, \ - &_workspace_bytes); \ - } \ - } \ - } -#else -#define CONFIG_AND_COMMIT(DESC, PREC, DOM, TYPE) \ - { \ - DESC = std::make_shared>( \ - _n); \ - set_stride_advanced(DESC); \ - DESC->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _fwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, \ - _batch); \ - if (_is_user_specified_dir_and_placement && _is_inplace) \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - oneapi::mkl::dft::config_value::INPLACE); \ - else \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - oneapi::mkl::dft::config_value::NOT_INPLACE); \ - DESC->commit(*_q); \ - } -#endif - - if (_input_type == library_data_t::complex_float && - _output_type == library_data_t::complex_float) { - CONFIG_AND_COMMIT(_desc_sc, SINGLE, COMPLEX, float); - } else if (_input_type == library_data_t::complex_double && - _output_type == library_data_t::complex_double) { - CONFIG_AND_COMMIT(_desc_dc, DOUBLE, COMPLEX, double); - } else if ((_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) || - (_input_type == library_data_t::complex_float && - _output_type == library_data_t::real_float)) { - CONFIG_AND_COMMIT(_desc_sr, SINGLE, REAL, float); - } else if ((_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double) || - (_input_type == library_data_t::complex_double && - _output_type == library_data_t::real_double)) { - CONFIG_AND_COMMIT(_desc_dr, DOUBLE, REAL, double); - } else { - throw sycl::exception(sycl::make_error_code(sycl::errc::invalid), - "invalid fft type"); - } -#undef CONFIG_AND_COMMIT - } - - template - void init(int dim, T *n, T *inembed, T istride, T idist, - library_data_t input_type, T *onembed, T ostride, T odist, - library_data_t output_type, T batch, - std::optional> - direction_and_placement) { - if (direction_and_placement.has_value()) { - _is_user_specified_dir_and_placement = true; - _direction = direction_and_placement->first; - _is_inplace = direction_and_placement->second; - } - _n.resize(dim); - _inembed.resize(dim); - _onembed.resize(dim); - _input_type = input_type; - _output_type = output_type; - for (int i = 0; i < dim; i++) { - _n[i] = n[i]; - } - if (inembed && onembed) { - for (int i = 0; i < dim; i++) { - _inembed[i] = inembed[i]; - _onembed[i] = onembed[i]; - } - _istride = istride; - _ostride = ostride; - - if ((_input_type == library_data_t::real_float && - _output_type == library_data_t::complex_float) || - (_input_type == library_data_t::real_double && - _output_type == library_data_t::complex_double)) { - _fwd_dist = idist; - _bwd_dist = odist; - } else if ((_output_type == library_data_t::real_float && - _input_type == library_data_t::complex_float) || - (_output_type == library_data_t::real_double && - _input_type == library_data_t::complex_double)) { - _fwd_dist = odist; - _bwd_dist = idist; - } else { - if (_is_user_specified_dir_and_placement && - (_direction == fft_direction::backward)) { - _fwd_dist = odist; - _bwd_dist = idist; - } else { - _fwd_dist = idist; - _bwd_dist = odist; - } - } - } else { - _is_basic = true; - } - _batch = batch; - _dim = dim; - - if (_is_basic) - config_and_commit_basic(); - else - config_and_commit_advanced(); - } - template - void set_stride_advanced(std::shared_ptr desc) { - if (_dim == 1) { - std::int64_t input_stride[2] = {0, _istride}; - std::int64_t output_stride[2] = {0, _ostride}; - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - input_stride); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - output_stride); - } else if (_dim == 2) { - std::int64_t input_stride[3] = {0, _inembed[1] * _istride, _istride}; - std::int64_t output_stride[3] = {0, _onembed[1] * _ostride, _ostride}; - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - input_stride); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - output_stride); - } else if (_dim == 3) { - std::int64_t input_stride[4] = {0, _inembed[2] * _inembed[1] * _istride, - _inembed[2] * _istride, _istride}; - std::int64_t output_stride[4] = {0, _onembed[2] * _onembed[1] * _ostride, - _onembed[2] * _ostride, _ostride}; - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - input_stride); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - output_stride); - } - } - - template void swap_distance(std::shared_ptr desc) { - desc->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _bwd_dist); - desc->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _fwd_dist); - std::int64_t temp = _bwd_dist; - _bwd_dist = _fwd_dist; - _fwd_dist = temp; - } - - template - void set_stride_and_distance_basic(std::shared_ptr desc) { - std::int64_t forward_distance = 0; - std::int64_t backward_distance = 0; - -#define SET_STRIDE \ - { \ - if (_direction == fft_direction::forward) { \ - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, \ - real_stride); \ - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, \ - complex_stride); \ - } else { \ - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, \ - complex_stride); \ - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, \ - real_stride); \ - } \ - } - if (_dim == 1) { - if constexpr (Is_inplace) { - std::int64_t real_stride[2] = {0, 1}; - std::int64_t complex_stride[2] = {0, 1}; - SET_STRIDE; - forward_distance = 2 * (_n[0] / 2 + 1); - backward_distance = _n[0] / 2 + 1; - } else { - std::int64_t real_stride[2] = {0, 1}; - std::int64_t complex_stride[2] = {0, 1}; - SET_STRIDE; - forward_distance = _n[0]; - backward_distance = _n[0] / 2 + 1; - } - } else if (_dim == 2) { - if constexpr (Is_inplace) { - std::int64_t complex_stride[3] = {0, _n[1] / 2 + 1, 1}; - std::int64_t real_stride[3] = {0, 2 * (_n[1] / 2 + 1), 1}; - SET_STRIDE; - forward_distance = _n[0] * 2 * (_n[1] / 2 + 1); - backward_distance = _n[0] * (_n[1] / 2 + 1); - } else { - std::int64_t complex_stride[3] = {0, _n[1] / 2 + 1, 1}; - std::int64_t real_stride[3] = {0, _n[1], 1}; - SET_STRIDE; - forward_distance = _n[0] * _n[1]; - backward_distance = _n[0] * (_n[1] / 2 + 1); - } - } else if (_dim == 3) { - if constexpr (Is_inplace) { - std::int64_t complex_stride[4] = {0, _n[1] * (_n[2] / 2 + 1), - _n[2] / 2 + 1, 1}; - std::int64_t real_stride[4] = {0, _n[1] * 2 * (_n[2] / 2 + 1), - 2 * (_n[2] / 2 + 1), 1}; - SET_STRIDE; - forward_distance = _n[0] * _n[1] * 2 * (_n[2] / 2 + 1); - backward_distance = _n[0] * _n[1] * (_n[2] / 2 + 1); - } else { - std::int64_t complex_stride[4] = {0, _n[1] * (_n[2] / 2 + 1), - _n[2] / 2 + 1, 1}; - std::int64_t real_stride[4] = {0, _n[1] * _n[2], _n[2], 1}; - SET_STRIDE; - forward_distance = _n[0] * _n[1] * _n[2]; - backward_distance = _n[0] * _n[1] * (_n[2] / 2 + 1); - } - } -#undef SET_STRIDE - desc->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - forward_distance); - desc->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - backward_distance); - } - -#define COMPUTE(DESC) \ - { \ - if (_is_inplace) { \ - auto data_input = dpct::detail::get_memory(input); \ - if (_direction == fft_direction::forward) { \ - oneapi::mkl::dft::compute_forward< \ - std::remove_reference_t, T>(*DESC, data_input); \ - } else { \ - oneapi::mkl::dft::compute_backward< \ - std::remove_reference_t, T>(*DESC, data_input); \ - } \ - } else { \ - auto data_input = dpct::detail::get_memory(input); \ - auto data_output = dpct::detail::get_memory(output); \ - if (_direction == fft_direction::forward) { \ - oneapi::mkl::dft::compute_forward< \ - std::remove_reference_t, T, T>(*DESC, data_input, \ - data_output); \ - } else { \ - oneapi::mkl::dft::compute_backward< \ - std::remove_reference_t, T, T>(*DESC, data_input, \ - data_output); \ - } \ - } \ - } - - template - void compute_complex(T *input, T *output, fft_direction direction) { - bool is_this_compute_inplace = input == output; - - if (!_is_user_specified_dir_and_placement) { - // The complex domain descriptor need different config values if the - // FFT direction or placement is different. - // Here we check the conditions, and new config values are set and - // re-committed if needed. - if (direction != _direction || is_this_compute_inplace != _is_inplace) { - if constexpr (Precision == oneapi::mkl::dft::precision::SINGLE) { - if (direction != _direction) { - swap_distance(_desc_sc); - _direction = direction; - } - if (is_this_compute_inplace != _is_inplace) { - _is_inplace = is_this_compute_inplace; -#ifdef __INTEL_MKL__ - if (_is_inplace) { - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - } else { - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - } -#else - if (_is_inplace) { - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - } else { - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - } -#endif - } - _desc_sc->commit(*_q); - } else { - if (direction != _direction) { - swap_distance(_desc_dc); - _direction = direction; - } - if (is_this_compute_inplace != _is_inplace) { - _is_inplace = is_this_compute_inplace; -#ifdef __INTEL_MKL__ - if (_is_inplace) { - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - } else { - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - } -#else - if (_is_inplace) { - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - } else { - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - } -#endif - } - _desc_dc->commit(*_q); - } - } - } - - if constexpr (Precision == oneapi::mkl::dft::precision::SINGLE) { - COMPUTE(_desc_sc); - } else { - COMPUTE(_desc_dc); - } - } - - template - void compute_real(T *input, T *output) { - bool is_this_compute_inplace = input == output; - - if (!_is_user_specified_dir_and_placement) { - // The real domain descriptor need different config values if the - // FFT placement is different. - // Here we check the condition, and new config values are set and - // re-committed if needed. - if (is_this_compute_inplace != _is_inplace) { - if constexpr (Precision == oneapi::mkl::dft::precision::SINGLE) { - _is_inplace = is_this_compute_inplace; - if (_is_inplace) { -#ifdef __INTEL_MKL__ - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); -#else - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); -#endif - if (_is_basic) - set_stride_and_distance_basic(_desc_sr); - } else { -#ifdef __INTEL_MKL__ - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); -#else - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); -#endif - if (_is_basic) - set_stride_and_distance_basic(_desc_sr); - } - _desc_sr->commit(*_q); - } else { - _is_inplace = is_this_compute_inplace; - if (_is_inplace) { -#ifdef __INTEL_MKL__ - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); -#else - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); -#endif - if (_is_basic) - set_stride_and_distance_basic(_desc_dr); - } else { -#ifdef __INTEL_MKL__ - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); -#else - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); -#endif - if (_is_basic) - set_stride_and_distance_basic(_desc_dr); - } - _desc_dr->commit(*_q); - } - } - } - - if constexpr (Precision == oneapi::mkl::dft::precision::SINGLE) { - COMPUTE(_desc_sr); - } else { - COMPUTE(_desc_dr); - } - } -#undef COMPUTE - -private: - sycl::queue *_q = nullptr; - int _dim; - std::vector _n; - std::vector _inembed; - std::int64_t _istride; - std::int64_t _fwd_dist; - library_data_t _input_type; - std::vector _onembed; - std::int64_t _ostride; - std::int64_t _bwd_dist; - library_data_t _output_type; - std::int64_t _batch = 1; - bool _is_basic = false; - bool _is_inplace = false; - fft_direction _direction = fft_direction::forward; - bool _is_user_specified_dir_and_placement = false; - bool _use_external_workspace = false; - void *_external_workspace_ptr = nullptr; - size_t _workspace_bytes = 0; - bool _is_estimate_call = false; - size_t _workspace_estimate_bytes = 0; - std::shared_ptr> - _desc_sr; - std::shared_ptr> - _desc_dr; - std::shared_ptr> - _desc_sc; - std::shared_ptr> - _desc_dc; -}; - -using fft_engine_ptr = fft_engine *; -} // namespace fft -} // namespace dpct - -#endif // __DPCT_FFT_UTILS_HPP__ diff --git a/dpct/image.hpp b/dpct/image.hpp deleted file mode 100644 index b9bb24668..000000000 --- a/dpct/image.hpp +++ /dev/null @@ -1,901 +0,0 @@ -//==---- image.hpp --------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_IMAGE_HPP__ -#define __DPCT_IMAGE_HPP__ - -#include - -#include "memory.hpp" -#include "util.hpp" - -namespace dpct { - -enum class image_channel_data_type { - signed_int, - unsigned_int, - fp, -}; - -class image_channel; -class image_wrapper_base; -namespace detail { -/// Image object type traits, with accessor type and sampled data type defined. -/// The data type of an image accessor must be one of sycl::int4, sycl::uint4, -/// sycl::float4 and sycl::half4. The data type of accessors with 8bits/16bits -/// channel width will be 32 bits. sycl::half is an exception. -template struct image_trait { - using acc_data_t = sycl::vec; - template - using accessor_t = - sycl::accessor; - template - using array_accessor_t = - sycl::accessor; - using data_t = T; - using elem_t = T; - static constexpr image_channel_data_type data_type = - std::is_integral::value - ? (std::is_signed::value ? image_channel_data_type::signed_int - : image_channel_data_type::unsigned_int) - : image_channel_data_type::fp; - static constexpr int channel_num = 1; -}; -template <> -struct image_trait : public image_trait { - using data_t = std::uint8_t; - using elem_t = data_t; -}; -template <> -struct image_trait - : public image_trait { - using data_t = std::uint16_t; - using elem_t = data_t; -}; -template <> -struct image_trait : public image_trait { - using data_t = std::int8_t; - using elem_t = data_t; -}; -template <> -struct image_trait : public image_trait { - using data_t = std::int16_t; - using elem_t = data_t; -}; -template <> -struct image_trait - : public image_trait::value, signed char, unsigned char>::type> {}; - -template -struct image_trait> : public image_trait {}; - -template -struct image_trait> : public image_trait { - using data_t = sycl::vec; - static constexpr int channel_num = 2; -}; - -template -struct image_trait> - : public image_trait> { - static constexpr int channel_num = 3; -}; - -template -struct image_trait> : public image_trait { - using data_t = sycl::vec; - static constexpr int channel_num = 4; -}; - -/// Functor to fetch data from read result of an image accessor. -template struct fetch_data { - using return_t = typename image_trait::data_t; - using acc_data_t = typename image_trait::acc_data_t; - - return_t operator()(acc_data_t &&original_data) { - return (return_t)original_data.r(); - } -}; -template -struct fetch_data> : public fetch_data {}; -template struct fetch_data> { - using return_t = typename image_trait>::data_t; - using acc_data_t = typename image_trait>::acc_data_t; - - return_t operator()(acc_data_t &&origin_data) { - return return_t(origin_data.r(), origin_data.g()); - } -}; -template -struct fetch_data> - : public fetch_data> {}; -template struct fetch_data> { - using return_t = typename image_trait>::data_t; - using acc_data_t = typename image_trait>::acc_data_t; - - return_t operator()(acc_data_t &&origin_data) { - return return_t(origin_data.r(), origin_data.g(), origin_data.b(), - origin_data.a()); - } -}; - -/// Create image according with given type \p T and \p dims. -template static image_wrapper_base *create_image_wrapper(int dims); - -/// Create image with given data type \p T, channel order and dims -template -static image_wrapper_base *create_image_wrapper(unsigned channel_num, int dims); - -/// Create image with channel info and specified dimensions. -static image_wrapper_base *create_image_wrapper(image_channel channel, int dims); - -} // namespace detail - -/// Image channel info, include channel number, order, data width and type -class image_channel { - image_channel_data_type _type = image_channel_data_type::signed_int; - /// Number of channels. - unsigned _channel_num = 0; - /// Total size of all channels in bytes. - unsigned _total_size = 0; - /// Size of each channel in bytes. - unsigned _channel_size = 0; - -public: - /// Create image channel info according to template argument \p T. - template static image_channel create() { - image_channel channel; - channel.set_channel_size(detail::image_trait::channel_num, - sizeof(typename detail::image_trait::elem_t) * - 8); - channel.set_channel_data_type(detail::image_trait::data_type); - return channel; - } - - image_channel() = default; - - image_channel_data_type get_channel_data_type() { return _type; } - void set_channel_data_type(image_channel_data_type type) { _type = type; } - - unsigned get_total_size() { return _total_size; } - - unsigned get_channel_num() { return _channel_num; } - void set_channel_num(unsigned channel_num) { - _channel_num = channel_num; - _total_size = _channel_size * _channel_num; - } - - /// image_channel constructor. - /// \param r Channel r width in bits. - /// \param g Channel g width in bits. Should be same with \p r, or zero. - /// \param b Channel b width in bits. Should be same with \p g, or zero. - /// \param a Channel a width in bits. Should be same with \p b, or zero. - /// \param data_type Image channel data type: signed_nt, unsigned_int or fp. - image_channel(int r, int g, int b, int a, image_channel_data_type data_type) { - _type = data_type; - if (a) { - assert(r == a && "SYCL doesn't support different channel size"); - assert(r == b && "SYCL doesn't support different channel size"); - assert(r == g && "SYCL doesn't support different channel size"); - set_channel_size(4, a); - } else if (b) { - assert(r == b && "SYCL doesn't support different channel size"); - assert(r == g && "SYCL doesn't support different channel size"); - set_channel_size(3, b); - } else if (g) { - assert(r == g && "SYCL doesn't support different channel size"); - set_channel_size(2, g); - } else { - set_channel_size(1, r); - } - } - - sycl::image_channel_type get_channel_type() const { - if (_channel_size == 4) { - if (_type == image_channel_data_type::signed_int) - return sycl::image_channel_type::signed_int32; - else if (_type == image_channel_data_type::unsigned_int) - return sycl::image_channel_type::unsigned_int32; - else if (_type == image_channel_data_type::fp) - return sycl::image_channel_type::fp32; - } else if (_channel_size == 2) { - if (_type == image_channel_data_type::signed_int) - return sycl::image_channel_type::signed_int16; - else if (_type == image_channel_data_type::unsigned_int) - return sycl::image_channel_type::unsigned_int16; - else if (_type == image_channel_data_type::fp) - return sycl::image_channel_type::fp16; - } else { - if (_type == image_channel_data_type::signed_int) - return sycl::image_channel_type::signed_int8; - else if (_type == image_channel_data_type::unsigned_int) - return sycl::image_channel_type::unsigned_int8; - } - assert(false && "unexpected channel data kind and channel size"); - return sycl::image_channel_type::signed_int32; - } - void set_channel_type(sycl::image_channel_type type) { - switch (type) { - case sycl::image_channel_type::unsigned_int8: - _type = image_channel_data_type::unsigned_int; - _channel_size = 1; - break; - case sycl::image_channel_type::unsigned_int16: - _type = image_channel_data_type::unsigned_int; - _channel_size = 2; - break; - case sycl::image_channel_type::unsigned_int32: - _type = image_channel_data_type::unsigned_int; - _channel_size = 4; - break; - case sycl::image_channel_type::signed_int8: - _type = image_channel_data_type::signed_int; - _channel_size = 1; - break; - case sycl::image_channel_type::signed_int16: - _type = image_channel_data_type::signed_int; - _channel_size = 2; - break; - case sycl::image_channel_type::signed_int32: - _type = image_channel_data_type::signed_int; - _channel_size = 4; - break; - case sycl::image_channel_type::fp16: - _type = image_channel_data_type::fp; - _channel_size = 2; - break; - case sycl::image_channel_type::fp32: - _type = image_channel_data_type::fp; - _channel_size = 4; - break; - default: - break; - } - _total_size = _channel_size * _channel_num; - } - - sycl::image_channel_order get_channel_order() const { - switch (_channel_num) { - case 1: - return sycl::image_channel_order::r; - case 2: - return sycl::image_channel_order::rg; - case 3: - return sycl::image_channel_order::rgb; - case 4: - return sycl::image_channel_order::rgba; - default: - return sycl::image_channel_order::r; - } - } - /// Get the size for each channel in bits. - unsigned get_channel_size() const { return _channel_size * 8; } - - /// Set channel size. - /// \param in_channel_num Channels number to set. - /// \param channel_size Size for each channel in bits. - void set_channel_size(unsigned in_channel_num, - unsigned channel_size) { - if (in_channel_num < _channel_num) - return; - _channel_num = in_channel_num; - _channel_size = channel_size / 8; - _total_size = _channel_size * _channel_num; - } -}; - -/// 2D or 3D matrix data for image. -class image_matrix { - image_channel _channel; - int _range[3] = {1, 1, 1}; - int _dims = 0; - void *_host_data = nullptr; - - /// Set range of each dimension. - template void set_range(sycl::range range) { - for (int i = 0; i < dimensions; ++i) - _range[i] = range[i]; - _dims = dimensions; - } - - template - sycl::range get_range(integer_sequence) { - return sycl::range(_range[DimIdx]...); - } - -public: - /// Constructor with channel info and dimension size info. - template - image_matrix(image_channel channel, sycl::range range) - : _channel(channel) { - set_range(range); - _host_data = std::malloc(range.size() * _channel.get_total_size()); - } - image_matrix(sycl::image_channel_type channel_type, unsigned channel_num, - size_t x, size_t y) { - _channel.set_channel_type(channel_type); - _channel.set_channel_num(channel_num); - _dims = 1; - _range[0] = x; - if (y) { - _dims = 2; - _range[1] = y; - } - _host_data = std::malloc(_range[0] * _range[1] * _channel.get_total_size()); - } - - /// Construct a new image class with the matrix data. - template sycl::image *create_image() { - return create_image(_channel); - } - /// Construct a new image class with the matrix data. - template - sycl::image *create_image(image_channel channel) { - return new sycl::image( - _host_data, channel.get_channel_order(), channel.get_channel_type(), - get_range(make_index_sequence()), - sycl::property::image::use_host_ptr()); - } - - /// Get channel info. - inline image_channel get_channel() { return _channel; } - /// Get range of the image. - sycl::range<3> get_range() { - return sycl::range<3>(_range[0], _range[1], _range[2]); - } - /// Get matrix dims. - inline int get_dims() { return _dims; } - /// Convert to pitched data. - pitched_data to_pitched_data() { - return pitched_data(_host_data, _range[0] * _channel.get_total_size(), - _range[0], _range[1]); - } - - ~image_matrix() { - if (_host_data) - std::free(_host_data); - _host_data = nullptr; - } -}; -using image_matrix_p = image_matrix *; - -enum class image_data_type { matrix, linear, pitch, unsupport }; - -/// Image data info. -class image_data { -public: - image_data() { _type = image_data_type::unsupport; } - image_data(image_matrix_p matrix_data) { set_data(matrix_data); } - image_data(void *data_ptr, size_t x_size, image_channel channel) { - set_data(data_ptr, x_size, channel); - } - image_data(void *data_ptr, size_t x_size, size_t y_size, size_t pitch_size, - image_channel channel) { - set_data(data_ptr, x_size, y_size, pitch_size, channel); - } - void set_data(image_matrix_p matrix_data) { - _type = image_data_type::matrix; - _data = matrix_data; - _channel = matrix_data->get_channel(); - } - void set_data(void *data_ptr, size_t x_size, image_channel channel) { - _type = image_data_type::linear; - _data = data_ptr; - _x = x_size; - _channel = channel; - } - void set_data(void *data_ptr, size_t x_size, size_t y_size, size_t pitch_size, - image_channel channel) { - _type = image_data_type::pitch; - _data = data_ptr; - _x = x_size; - _y = y_size; - _pitch = pitch_size; - _channel = channel; - } - - image_data_type get_data_type() const { return _type; } - void set_data_type(image_data_type type) { _type = type; } - - void *get_data_ptr() const { return _data; } - void set_data_ptr(void *data) { _data = data; } - - size_t get_x() const { return _x; } - void set_x(size_t x) { _x = x; } - - size_t get_y() const { return _y; } - void set_y(size_t y) { _y = y; } - - size_t get_pitch() const { return _pitch; } - void set_pitch(size_t pitch) { _pitch = pitch; } - - image_channel get_channel() const { return _channel; } - void set_channel(image_channel channel) { _channel = channel; } - - image_channel_data_type get_channel_data_type() { - return _channel.get_channel_data_type(); - } - void set_channel_data_type(image_channel_data_type type) { - _channel.set_channel_data_type(type); - } - - unsigned get_channel_size() { return _channel.get_channel_size(); } - void set_channel_size(unsigned channel_num, unsigned channel_size) { - return _channel.set_channel_size(channel_num, channel_size); - } - - unsigned get_channel_num() { return _channel.get_channel_num(); } - void set_channel_num(unsigned num) { - return _channel.set_channel_num(num); - } - - sycl::image_channel_type get_channel_type() { - return _channel.get_channel_type(); - } - void set_channel_type(sycl::image_channel_type type) { - return _channel.set_channel_type(type); - } - -private: - image_data_type _type; - void *_data = nullptr; - size_t _x, _y, _pitch; - image_channel _channel; -}; - -/// Image sampling info, include addressing mode, filtering mode and -/// normalization info. -class sampling_info { - sycl::addressing_mode _addressing_mode = - sycl::addressing_mode::clamp_to_edge; - sycl::filtering_mode _filtering_mode = sycl::filtering_mode::nearest; - sycl::coordinate_normalization_mode _coordinate_normalization_mode = - sycl::coordinate_normalization_mode::unnormalized; - -public: - sycl::addressing_mode get_addressing_mode() { return _addressing_mode; } - void set(sycl::addressing_mode addressing_mode) { _addressing_mode = addressing_mode; } - - sycl::filtering_mode get_filtering_mode() { return _filtering_mode; } - void set(sycl::filtering_mode filtering_mode) { _filtering_mode = filtering_mode; } - - sycl::coordinate_normalization_mode get_coordinate_normalization_mode() { - return _coordinate_normalization_mode; - } - void set(sycl::coordinate_normalization_mode coordinate_normalization_mode) { - _coordinate_normalization_mode = coordinate_normalization_mode; - } - - bool is_coordinate_normalized() { - return _coordinate_normalization_mode == - sycl::coordinate_normalization_mode::normalized; - } - void set_coordinate_normalization_mode(int is_normalized) { - _coordinate_normalization_mode = - is_normalized ? sycl::coordinate_normalization_mode::normalized - : sycl::coordinate_normalization_mode::unnormalized; - } - void - set(sycl::addressing_mode addressing_mode, - sycl::filtering_mode filtering_mode, - sycl::coordinate_normalization_mode coordinate_normalization_mode) { - set(addressing_mode); - set(filtering_mode); - set(coordinate_normalization_mode); - } - void set(sycl::addressing_mode addressing_mode, - sycl::filtering_mode filtering_mode, int is_normalized) { - set(addressing_mode); - set(filtering_mode); - set_coordinate_normalization_mode(is_normalized); - } - - sycl::sampler get_sampler() { - return sycl::sampler(_coordinate_normalization_mode, _addressing_mode, - _filtering_mode); - } -}; - -/// Image base class. -class image_wrapper_base { - sampling_info _sampling_info; - image_data _data; - -public: - virtual ~image_wrapper_base() = 0; - - void attach(image_data data) { set_data(data); } - /// Attach matrix data to this class. - void attach(image_matrix *matrix) { - detach(); - image_wrapper_base::set_data(image_data(matrix)); - } - /// Attach matrix data to this class. - void attach(image_matrix *matrix, image_channel channel) { - attach(matrix); - image_wrapper_base::set_channel(channel); - } - /// Attach linear data to this class. - void attach(const void *ptr, size_t count) { - attach(ptr, count, get_channel()); - } - /// Attach linear data to this class. - void attach(const void *ptr, size_t count, image_channel channel) { - detach(); - image_wrapper_base::set_data(image_data(const_cast(ptr), count, channel)); - } - /// Attach 2D data to this class. - void attach(const void *data, size_t x, size_t y, size_t pitch) { - attach(data, x, y, pitch, get_channel()); - } - /// Attach 2D data to this class. - void attach(const void *data, size_t x, size_t y, size_t pitch, - image_channel channel) { - detach(); - image_wrapper_base::set_data( - image_data(const_cast(data), x, y, pitch, channel)); - } - /// Detach data. - virtual void detach() {} - - sampling_info get_sampling_info() { return _sampling_info; } - void set_sampling_info(sampling_info info) { - _sampling_info = info; - } - const image_data &get_data() { return _data; } - void set_data(image_data data) { _data = data; } - - image_channel get_channel() { return _data.get_channel(); } - void set_channel(image_channel channel) { _data.set_channel(channel); } - - image_channel_data_type get_channel_data_type() { - return _data.get_channel_data_type(); - } - void set_channel_data_type(image_channel_data_type type) { - _data.set_channel_data_type(type); - } - - unsigned get_channel_size() { return _data.get_channel_size(); } - void set_channel_size(unsigned channel_num, unsigned channel_size) { - return _data.set_channel_size(channel_num, channel_size); - } - - sycl::addressing_mode get_addressing_mode() { - return _sampling_info.get_addressing_mode(); - } - void set(sycl::addressing_mode addressing_mode) { - _sampling_info.set(addressing_mode); - } - - sycl::filtering_mode get_filtering_mode() { - return _sampling_info.get_filtering_mode(); - } - void set(sycl::filtering_mode filtering_mode) { - _sampling_info.set(filtering_mode); - } - - sycl::coordinate_normalization_mode get_coordinate_normalization_mode() { - return _sampling_info.get_coordinate_normalization_mode(); - } - void - set(sycl::coordinate_normalization_mode coordinate_normalization_mode) { - _sampling_info.set(coordinate_normalization_mode); - } - - bool is_coordinate_normalized() { - return _sampling_info.is_coordinate_normalized(); - } - void set_coordinate_normalization_mode(int is_normalized) { - _sampling_info.set_coordinate_normalization_mode(is_normalized); - } - void - set(sycl::addressing_mode addressing_mode, - sycl::filtering_mode filtering_mode, - sycl::coordinate_normalization_mode coordinate_normalization_mode) { - set(addressing_mode); - set(filtering_mode); - set(coordinate_normalization_mode); - } - void set(sycl::addressing_mode addressing_mode, - sycl::filtering_mode filtering_mode, int is_normalized) { - set(addressing_mode); - set(filtering_mode); - set_coordinate_normalization_mode(is_normalized); - } - - unsigned get_channel_num() { return _data.get_channel_num(); } - void set_channel_num(unsigned num) { - return _data.set_channel_num(num); - } - - sycl::image_channel_type get_channel_type() { - return _data.get_channel_type(); - } - void set_channel_type(sycl::image_channel_type type) { - return _data.set_channel_type(type); - } - - sycl::sampler get_sampler() { - sycl::sampler smp = _sampling_info.get_sampler(); - /// linear memory only used for sycl::filtering_mode::nearest. - if (_data.get_data_type() == image_data_type::linear) { - smp = sycl::sampler(smp.get_coordinate_normalization_mode(), - smp.get_addressing_mode(), - sycl::filtering_mode::nearest); - } - return smp; - } -}; -inline image_wrapper_base::~image_wrapper_base() {} -using image_wrapper_base_p = image_wrapper_base *; - -template class image_accessor_ext; - -/// Image class, wrapper of sycl::image. -template class image_wrapper : public image_wrapper_base { - sycl::image *_image = nullptr; - -#ifndef DPCT_USM_LEVEL_NONE - std::vector _host_buffer; -#endif - - void create_image(sycl::queue q) { - auto &data = get_data(); - if (data.get_data_type() == image_data_type::matrix) { - _image = static_cast(data.get_data_ptr()) - ->create_image(data.get_channel()); - return; - } - auto ptr = data.get_data_ptr(); - auto channel = data.get_channel(); - - if (detail::get_pointer_attribute(q, ptr) == detail::pointer_access_attribute::device_only) { -#ifdef DPCT_USM_LEVEL_NONE - ptr = get_buffer(ptr) - .template get_access() - .get_pointer(); -#else - auto sz = data.get_x(); - if (data.get_data_type() == image_data_type::pitch) - sz *= channel.get_total_size() * data.get_y(); - _host_buffer.resize(sz); - q.memcpy(_host_buffer.data(), ptr, sz).wait(); - ptr = _host_buffer.data(); -#endif - } - - if constexpr (dimensions == 1) { - assert(data.get_data_type() == image_data_type::linear); - _image = new sycl::image<1>( - ptr, channel.get_channel_order(), channel.get_channel_type(), - sycl::range<1>(data.get_x() / channel.get_total_size())); - } else if constexpr (dimensions == 2) { - assert(data.get_data_type() == image_data_type::pitch); - _image = new sycl::image<2>(ptr, channel.get_channel_order(), - channel.get_channel_type(), - sycl::range<2>(data.get_x(), data.get_y()), - sycl::range<1>(data.get_pitch())); - } else { - throw std::runtime_error("3D image only support matrix data"); - } - return; - } - -public: - using acc_data_t = typename detail::image_trait::acc_data_t; - using accessor_t = - typename image_accessor_ext::accessor_t; - - image_wrapper() { set_channel(image_channel::create()); } - ~image_wrapper() { detach(); } - - /// Get image accessor. - accessor_t get_access(sycl::handler &cgh, sycl::queue &q = get_default_queue()) { - if (!_image) - create_image(q); - return accessor_t(*_image, cgh); - } - - /// Detach data. - void detach() override { - if (_image) - delete _image; - _image = nullptr; - } -}; - -/// Wrap sampler and image accessor together. -template -class image_accessor_ext { -public: - using accessor_t = - typename detail::image_trait::template accessor_t; - using data_t = typename detail::image_trait::data_t; - sycl::sampler _sampler; - accessor_t _img_acc; - -public: - image_accessor_ext(sycl::sampler sampler, accessor_t acc) - : _sampler(sampler), _img_acc(acc) {} - - /// Read data from accessor. - template - typename std::enable_if::type read(float x, float y, - float z) { - return detail::fetch_data()( - _img_acc.read(sycl::float4(x, y, z, 0), _sampler)); - } - /// Read data from accessor. - template ::value - &&std::is_integral::value - &&std::is_integral::value> - typename std::enable_if::type read(Coord0 x, Coord1 y, - Coord2 z) { - return detail::fetch_data()( - _img_acc.read(sycl::int4(x, y, z, 0), _sampler)); - } - /// Read data from accessor. - template - typename std::enable_if::type read(float x, float y) { - return detail::fetch_data()( - _img_acc.read(sycl::float2(x, y), _sampler)); - } - /// Read data from accessor. - template ::value - &&std::is_integral::value> - typename std::enable_if::type read(Coord0 x, Coord1 y) { - return detail::fetch_data()( - _img_acc.read(sycl::int2(x, y), _sampler)); - } - /// Read data from accessor. - template - typename std::enable_if::type read(float x) { - return detail::fetch_data()(_img_acc.read(x, _sampler)); - } - /// Read data from accessor. - template ::value> - typename std::enable_if::type read(CoordT x) { - return detail::fetch_data()(_img_acc.read(x, _sampler)); - } -}; - -template class image_accessor_ext { -public: - using accessor_t = - typename detail::image_trait::template array_accessor_t; - using data_t = typename detail::image_trait::data_t; - sycl::sampler _sampler; - accessor_t _img_acc; - -public: - image_accessor_ext(sycl::sampler sampler, accessor_t acc) - : _sampler(sampler), _img_acc(acc) {} - - /// Read data from accessor. - template - typename std::enable_if::type read(int index, float x, - float y) { - return detail::fetch_data()( - _img_acc[index].read(sycl::float2(x, y), _sampler)); - } - /// Read data from accessor. - template - typename std::enable_if::type read(int index, int x, int y) { - return detail::fetch_data()( - _img_acc[index].read(sycl::int2(x, y), _sampler)); - } - /// Read data from accessor. - template - typename std::enable_if::type read(int index, float x) { - return detail::fetch_data()( - _img_acc[index].read(x, _sampler)); - } - /// Read data from accessor. - template - typename std::enable_if::type read(int index, int x) { - return detail::fetch_data()( - _img_acc[index].read(x, _sampler)); - } -}; - -/// Create image wrapper according to image data and sampling info. -/// \return Pointer to image wrapper base class. -/// \param data Image data used to create image wrapper. -/// \param info Image sampling info used to create image wrapper. -/// \returns Pointer to base class of created image wrapper object. -static inline image_wrapper_base *create_image_wrapper(image_data data, - sampling_info info) { - image_channel channel; - int dims = 1; - if (data.get_data_type() == image_data_type::matrix) { - auto matrix = (image_matrix_p)data.get_data_ptr(); - channel = matrix->get_channel(); - dims = matrix->get_dims(); - } else { - if (data.get_data_type() == image_data_type::pitch) { - dims = 2; - } - channel = data.get_channel(); - } - - if (auto ret = detail::create_image_wrapper(channel, dims)) { - ret->set_sampling_info(info); - ret->set_data(data); - return ret; - } - return nullptr; -} - -namespace detail { -/// Create image according with given type \p T and \p dims. -template static image_wrapper_base *create_image_wrapper(int dims) { - switch (dims) { - case 1: - return new image_wrapper(); - case 2: - return new image_wrapper(); - case 3: - return new image_wrapper(); - default: - return nullptr; - } -} -/// Create image with given data type \p T, channel order and dims -template -static image_wrapper_base *create_image_wrapper(unsigned channel_num, int dims) { - switch (channel_num) { - case 1: - return create_image_wrapper(dims); - case 2: - return create_image_wrapper>(dims); - case 3: - return create_image_wrapper>(dims); - case 4: - return create_image_wrapper>(dims); - default: - return nullptr; - } -} - -/// Create image with channel info and specified dimensions. -static image_wrapper_base *create_image_wrapper(image_channel channel, int dims) { - switch (channel.get_channel_type()) { - case sycl::image_channel_type::fp16: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::fp32: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::signed_int8: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::signed_int16: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::signed_int32: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::unsigned_int8: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::unsigned_int16: - return create_image_wrapper(channel.get_channel_num(), dims); - case sycl::image_channel_type::unsigned_int32: - return create_image_wrapper(channel.get_channel_num(), dims); - default: - return nullptr; - } -} -} // namespace detail - -} // namespace dpct - -#endif // !__DPCT_IMAGE_HPP__ diff --git a/dpct/kernel.hpp b/dpct/kernel.hpp deleted file mode 100644 index 11d1321bb..000000000 --- a/dpct/kernel.hpp +++ /dev/null @@ -1,459 +0,0 @@ -//==---- kernel.hpp -------------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_KERNEL_HPP__ -#define __DPCT_KERNEL_HPP__ - -#include -#ifdef _WIN32 -#include -#include -#else -#include -#endif - -#if defined(__has_include) && __has_include() -#include -#elif defined(__has_include) && __has_include() -#include -#else -#error "SYCLomatic runtime requires C++ filesystem support" -#endif - -#include -#include -#include - -namespace dpct { - -typedef void (*kernel_functor)(sycl::queue &, const sycl::nd_range<3> &, - unsigned int, void **, void **); - -struct kernel_function_info { - int max_work_group_size = 0; -}; - -static inline void get_kernel_function_info(kernel_function_info *kernel_info, - const void *function) { - kernel_info->max_work_group_size = - dpct::dev_mgr::instance() - .current_device() - .get_info(); -} -static inline kernel_function_info -get_kernel_function_info(const void *function) { - kernel_function_info kernel_info; - kernel_info.max_work_group_size = - dpct::dev_mgr::instance() - .current_device() - .get_info(); - return kernel_info; -} - - -namespace detail { - -#if defined(__has_include) && __has_include() -namespace fs = std::filesystem; -#else -namespace fs = std::experimental::filesystem; -#endif - -/// Write data to temporary file and return absolute path to temporary file. -/// Temporary file is created in a temporary directory both of which have random -/// names with only the user having access permissions. Only one temporary file -/// will be created in the temporary directory. -static inline fs::path write_data_to_file(char const *const data, size_t size) { - std::error_code ec; - - if (sizeof(size_t) >= sizeof(std::streamsize) && - size > (std::numeric_limits::max)()) - throw std::runtime_error("data file too large"); - - // random number generator - std::random_device dev; - std::mt19937 prng(dev()); - std::uniform_int_distribution rand(0); - - // find temporary directory - auto tmp_dir = fs::temp_directory_path(ec); - if (ec) - throw std::runtime_error("could not find temporary directory"); - - // create private directory - std::stringstream directory; - fs::path directory_path; - constexpr int max_attempts = 5; - int i; - - for (i = 0; i < max_attempts; i++) { - directory << std::hex << rand(prng); - directory_path = tmp_dir / directory.str(); - if (fs::create_directory(directory_path)) { - break; - } - } - if (i == max_attempts) - throw std::runtime_error("could not create directory"); - - // only allow owner permissions to private directory - fs::permissions(directory_path, fs::perms::owner_all, ec); - if (ec) - throw std::runtime_error("could not set directory permissions"); - - // random filename in private directory - std::stringstream filename; - filename << std::hex << rand(prng); -#ifdef _WIN32 - auto filepath = directory_path / (filename.str() + ".dll"); -#else - auto filepath = directory_path / filename.str(); -#endif - - // write data to temporary file - auto outfile = std::ofstream(filepath, std::ios::out | std::ios::binary); - if (outfile) { - // only allow program to write file - fs::permissions(filepath, fs::perms::owner_write, ec); - if (ec) - throw std::runtime_error("could not set permissions"); - - outfile.write(data, size); - if (!outfile.good()) - throw std::runtime_error("could not write data"); - outfile.close(); - - // only allow program to read/execute file - fs::permissions(filepath, fs::perms::owner_read | fs::perms::owner_exec, - ec); - if (ec) - throw std::runtime_error("could not set permissions"); - } else - throw std::runtime_error("could not write data"); - - // check temporary file contents - auto infile = std::ifstream(filepath, std::ios::in | std::ios::binary); - if (infile) { - bool mismatch = false; - size_t cnt = 0; - - while (1) { - char c; - infile.get(c); - if (infile.eof()) - break; - if (c != data[cnt++]) - mismatch = true; - } - if (cnt != size || mismatch) - throw std::runtime_error("file contents not written correctly"); - } else - throw std::runtime_error("could not validate file"); - - if (!filepath.is_absolute()) - throw std::runtime_error("temporary filepath is not absolute"); - - return filepath; -} - -static inline uint16_t extract16(unsigned char const *const ptr) { - uint16_t ret = 0; - - ret |= static_cast(ptr[0]) << 0; - ret |= static_cast(ptr[1]) << 8; - - return (ret); -} - -static inline uint32_t extract32(unsigned char const *const ptr) { - uint32_t ret = 0; - - ret |= static_cast(ptr[0]) << 0; - ret |= static_cast(ptr[1]) << 8; - ret |= static_cast(ptr[2]) << 16; - ret |= static_cast(ptr[3]) << 24; - - return (ret); -} - -static inline uint64_t extract64(unsigned char const *const ptr) { - uint64_t ret = 0; - - ret |= static_cast(ptr[0]) << 0; - ret |= static_cast(ptr[1]) << 8; - ret |= static_cast(ptr[2]) << 16; - ret |= static_cast(ptr[3]) << 24; - ret |= static_cast(ptr[4]) << 32; - ret |= static_cast(ptr[5]) << 40; - ret |= static_cast(ptr[6]) << 48; - ret |= static_cast(ptr[7]) << 56; - - return (ret); -} - -static inline uint64_t get_lib_size(char const *const blob) { -#ifdef _WIN32 - /////////////////////////////////////////////////////////////////////// - // Analyze DOS stub - unsigned char const *const ublob = - reinterpret_cast(blob); - if (ublob[0] != 0x4d || ublob[1] != 0x5a) { - throw std::runtime_error("Blob is not a Windows DLL."); - } - uint32_t pe_header_offset = extract32(ublob + 0x3c); - - /////////////////////////////////////////////////////////////////////// - // Ananlyze PE-header - unsigned char const *const pe_header = ublob + pe_header_offset; - - // signature - uint32_t pe_signature = extract32(pe_header + 0); - if (pe_signature != 0x00004550) { - throw std::runtime_error("PE-header signature is not 0x00004550"); - } - - // machine - uint16_t machine = extract16(pe_header + 4); - if (machine != 0x8664) { - throw std::runtime_error("Only DLLs for x64 supported"); - } - - // number of sections - uint16_t number_of_sections = extract16(pe_header + 6); - - // sizeof optional header - uint16_t sizeof_optional_header = extract16(pe_header + 20); - - // magic - uint16_t magic = extract16(pe_header + 24); - if (magic != 0x10b && magic != 0x20b) { - throw std::runtime_error("MAGIC is not 0x010b or 0x020b"); - } - - /////////////////////////////////////////////////////////////////////// - // Analyze tail of optional header - constexpr int coff_header_size = 24; - - unsigned char const *const tail_of_optional_header = - pe_header + coff_header_size + sizeof_optional_header; - if (extract64(tail_of_optional_header - 8) != 0) { - throw std::runtime_error("Optional header not zero-padded"); - } - - /////////////////////////////////////////////////////////////////////// - // Analyze last section header - constexpr int section_header_size = 40; - unsigned char const *const last_section_header = - tail_of_optional_header + section_header_size * (number_of_sections - 1); - - uint32_t sizeof_raw_data = extract32(last_section_header + 16); - uint32_t pointer_to_raw_data = extract32(last_section_header + 20); - - return sizeof_raw_data + pointer_to_raw_data; -#else - if (blob[0] != 0x7F || blob[1] != 'E' || blob[2] != 'L' || blob[3] != 'F') - throw std::runtime_error("Blob is not in ELF format"); - - if (blob[4] != 0x02) - throw std::runtime_error("Only 64-bit headers are supported"); - - if (blob[5] != 0x01) - throw std::runtime_error("Only little-endian headers are supported"); - - unsigned char const *const ublob = - reinterpret_cast(blob); - uint64_t e_shoff = extract64(ublob + 0x28); - uint16_t e_shentsize = extract16(ublob + 0x3A); - uint16_t e_shnum = extract16(ublob + 0x3C); - - return e_shoff + (e_shentsize * e_shnum); -#endif -} - -#ifdef _WIN32 -class path_lib_record { -public: - void operator=(const path_lib_record &) = delete; - ~path_lib_record() { - for (auto entry : lib_to_path) { - FreeLibrary(static_cast(entry.first)); - fs::permissions(entry.second, fs::perms::owner_all); - fs::remove_all(entry.second.remove_filename()); - } - } - static void record_lib_path(fs::path path, void *library) { - lib_to_path[library] = path; - } - static void remove_lib(void *library) { - auto path = lib_to_path[library]; - std::error_code ec; - - FreeLibrary(static_cast(library)); - fs::permissions(path, fs::perms::owner_all); - if (fs::remove_all(path.remove_filename(), ec) != 2 || ec) - // one directory and one temporary file should have been deleted - throw std::runtime_error("Directory delete failed"); - - lib_to_path.erase(library); - } - -private: - static inline std::unordered_map lib_to_path; -}; -#endif - -} // namespace detail - -class kernel_library { -public: - kernel_library() : ptr{nullptr} {} - kernel_library(void *ptr) : ptr{ptr} {} - - operator void *() const { return ptr; } - -private: - void *ptr; -#ifdef _WIN32 - static inline detail::path_lib_record single_instance_to_trigger_destructor; -#endif -}; - -namespace detail { - -static inline kernel_library load_dl_from_data(char const *const data, - size_t size) { - fs::path filename = write_data_to_file(data, size); -#ifdef _WIN32 - void *so = LoadLibraryW(filename.wstring().c_str()); -#else - void *so = dlopen(filename.c_str(), RTLD_LAZY); -#endif - if (so == nullptr) - throw std::runtime_error("Failed to load kernel library"); - -#ifdef _WIN32 - detail::path_lib_record::record_lib_path(filename, so); -#else - std::error_code ec; - - // Windows DLL cannot be deleted while in use - if (fs::remove_all(filename.remove_filename(), ec) != 2 || ec) - // one directory and one temporary file should have been deleted - throw std::runtime_error("Directory delete failed"); -#endif - - return so; -} - -} // namespace detail - -/// Load kernel library and return a handle to use the library. -/// \param [in] name The name of the library. -static inline kernel_library load_kernel_library(const std::string &name) { - std::ifstream ifs; - ifs.open(name, std::ios::in | std::ios::binary); - - std::stringstream buffer; - buffer << ifs.rdbuf(); - - const std::string buffer_string = buffer.str(); - return detail::load_dl_from_data(buffer_string.c_str(), buffer_string.size()); -} - -/// Load kernel library whose image is alreay in memory and return a handle to -/// use the library. -/// \param [in] image A pointer to the image in memory. -static inline kernel_library load_kernel_library_mem(char const *const image) { - const size_t size = detail::get_lib_size(image); - - return detail::load_dl_from_data(image, size); -} - -/// Unload kernel library. -/// \param [in,out] library Handle to the library to be closed. -static inline void unload_kernel_library(const kernel_library &library) { -#ifdef _WIN32 - detail::path_lib_record::remove_lib(library); -#else - dlclose(library); -#endif -} - -class kernel_function { -public: - kernel_function() : ptr{nullptr} {} - kernel_function(dpct::kernel_functor ptr) : ptr{ptr} {} - - operator void *() const { return ((void *)ptr); } - - void operator()(sycl::queue &q, const sycl::nd_range<3> &range, - unsigned int a, void **args, void **extra) { - ptr(q, range, a, args, extra); - } - -private: - dpct::kernel_functor ptr; -}; - -/// Find kernel function in a kernel library and return its address. -/// \param [in] library Handle to the kernel library. -/// \param [in] name Name of the kernel function. -static inline dpct::kernel_function -get_kernel_function(kernel_library &library, const std::string &name) { -#ifdef _WIN32 - dpct::kernel_functor fn = reinterpret_cast( - GetProcAddress(static_cast(static_cast(library)), - (name + std::string("_wrapper")).c_str())); -#else - dpct::kernel_functor fn = reinterpret_cast( - dlsym(library, (name + std::string("_wrapper")).c_str())); -#endif - if (fn == nullptr) - throw std::runtime_error("Failed to get function"); - return fn; -} - -/// Invoke a kernel function. -/// \param [in] function kernel function. -/// \param [in] queue SYCL queue used to execute kernel -/// \param [in] groupRange SYCL group range -/// \param [in] localRange SYCL local range -/// \param [in] localMemSize The size of local memory required by the kernel -/// function. -/// \param [in] kernelParams Array of pointers to kernel arguments. -/// \param [in] extra Extra arguments. -static inline void invoke_kernel_function(dpct::kernel_function &function, - sycl::queue &queue, - sycl::range<3> groupRange, - sycl::range<3> localRange, - unsigned int localMemSize, - void **kernelParams, void **extra) { - function(queue, sycl::nd_range<3>(groupRange * localRange, localRange), - localMemSize, kernelParams, extra); -} - -/// Find image wrapper in a kernel library and return its address. -/// \param [in] library Handle to the kernel library. -/// \param [in] name Name of the target image wrapper. -static inline dpct::image_wrapper_base_p -get_image_wrapper(dpct::kernel_library &library, const std::string &name) { -#ifdef _WIN32 - dpct::image_wrapper_base_p fn = - reinterpret_cast(GetProcAddress( - static_cast(static_cast(library)), name.c_str())); -#else - dpct::image_wrapper_base_p fn = reinterpret_cast( - dlsym(library, name.c_str())); -#endif - if (fn == nullptr) - throw std::runtime_error("Failed to get image"); - return fn; -} - -} // namespace dpct -#endif // __DPCT_KERNEL_HPP__ diff --git a/dpct/lapack_utils.hpp b/dpct/lapack_utils.hpp deleted file mode 100644 index dac77d577..000000000 --- a/dpct/lapack_utils.hpp +++ /dev/null @@ -1,1953 +0,0 @@ -//==---- lapack_utils.hpp -------------------------*- C++ -*----------------==// -// -// Copyright (C) Intel Corporation -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// See https://llvm.org/LICENSE.txt for license information. -// -//===----------------------------------------------------------------------===// - -#ifndef __DPCT_LAPACK_UTILS_HPP__ -#define __DPCT_LAPACK_UTILS_HPP__ - -#include "memory.hpp" -#include "util.hpp" -#include "lib_common_utils.hpp" - -#include -#include - -namespace dpct { -namespace lapack { -/// Computes all eigenvalues and, optionally, eigenvectors of a real generalized -/// symmetric definite eigenproblem using a divide and conquer method. -/// \return Returns 0 if no synchronous exception, otherwise returns 1. -/// \param [in] queue Device queue where calculations will be performed. It must -/// have the in_order property when using the USM mode (DPCT_USM_LEVEL_NONE is -/// not defined). -/// \param [in] itype Must be 1 or 2 or 3. Specifies the problem type to be solved. -/// \param [in] jobz Must be job::novec or job::vec. -/// \param [in] uplo Must be uplo::upper or uplo::lower. -/// \param [in] n The order of the matrices A and B. -/// \param [in,out] a The symmetric matrix A. -/// \param [in] lda The leading dimension of matrix A. -/// \param [in,out] b The symmetric matrix B. -/// \param [in] ldb The leading dimension of matrix B. -/// \param [out] w Eigenvalues. -/// \param [in] scratchpad Scratchpad memory to be used by the routine -/// for storing intermediate results. -/// \param [in] scratchpad_size Size of scratchpad memory as a number of -/// floating point elements of type T. -/// \param [out] info If lapack synchronous exception is caught, the value -/// returned from info() method of the exception is set to \p info. -template -inline int sygvd(sycl::queue &queue, std::int64_t itype, oneapi::mkl::job jobz, - oneapi::mkl::uplo uplo, int n, T *a, int lda, T *b, int ldb, - T *w, T *scratchpad, int scratchpad_size, int *info) { -#ifdef DPCT_USM_LEVEL_NONE - auto info_buf = get_buffer(info); - auto a_buffer = get_buffer(a); - auto b_buffer = get_buffer(b); - auto w_buffer = get_buffer(w); - auto scratchpad_buffer = get_buffer(scratchpad); - int info_val = 0; - int ret_val = 0; - try { - oneapi::mkl::lapack::sygvd(queue, itype, jobz, uplo, n, a_buffer, lda, - b_buffer, ldb, w_buffer, scratchpad_buffer, - scratchpad_size); - } catch (oneapi::mkl::lapack::exception const& e) { - std::cerr << "Unexpected exception caught during call to LAPACK API: sygvd" - << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - info_val = static_cast(e.info()); - ret_val = 1; - } catch (sycl::exception const& e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - ret_val = 1; - } - queue.submit([&, info_val](sycl::handler &cgh) { - auto info_acc = info_buf.get_access(cgh); - cgh.single_task>( - [=]() { info_acc[0] = info_val; }); - }); - return ret_val; -#else - try { - oneapi::mkl::lapack::sygvd(queue, itype, jobz, uplo, n, a, lda, b, ldb, w, - scratchpad, scratchpad_size); - } catch (oneapi::mkl::lapack::exception const& e) { - std::cerr << "Unexpected exception caught during call to LAPACK API: sygvd" - << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - int info_val = static_cast(e.info()); - queue.memcpy(info, &info_val, sizeof(int)).wait(); - return 1; - } catch (sycl::exception const& e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - queue.memset(info, 0, sizeof(int)).wait(); - return 1; - } - queue.memset(info, 0, sizeof(int)); - return 0; -#endif -} -/// Computes all the eigenvalues, and optionally, the eigenvectors of a complex -/// generalized Hermitian positive-definite eigenproblem using a divide and -/// conquer method. -/// \return Returns 0 if no synchronous exception, otherwise returns 1. -/// \param [in] queue Device queue where calculations will be performed. It must -/// have the in_order property when using the USM mode (DPCT_USM_LEVEL_NONE is -/// not defined). -/// \param [in] itype Must be 1 or 2 or 3. Specifies the problem type to be solved. -/// \param [in] jobz Must be job::novec or job::vec. -/// \param [in] uplo Must be uplo::upper or uplo::lower. -/// \param [in] n The order of the matrices A and B. -/// \param [in,out] a The Hermitian matrix A. -/// \param [in] lda The leading dimension of matrix A. -/// \param [in,out] b The Hermitian matrix B. -/// \param [in] ldb The leading dimension of matrix B. -/// \param [in] w Eigenvalues. -/// \param [in] scratchpad Scratchpad memory to be used by the routine -/// for storing intermediate results. -/// \param [in] scratchpad_size Size of scratchpad memory as a number of -/// floating point elements of type T. -/// \param [out] info If lapack synchronous exception is caught, the value -/// returned from info() method of the exception is set to \p info. -template -inline int hegvd(sycl::queue &queue, std::int64_t itype, oneapi::mkl::job jobz, - oneapi::mkl::uplo uplo, int n, T *a, int lda, T *b, int ldb, - Tw *w, T *scratchpad, int scratchpad_size, int *info) { - using Ty = typename DataType::T2; -#ifdef DPCT_USM_LEVEL_NONE - auto info_buf = get_buffer(info); - auto a_buffer = get_buffer(a); - auto b_buffer = get_buffer(b); - auto w_buffer = get_buffer(w); - auto scratchpad_buffer = get_buffer(scratchpad); - int info_val = 0; - int ret_val = 0; - try { - oneapi::mkl::lapack::hegvd(queue, itype, jobz, uplo, n, a_buffer, lda, - b_buffer, ldb, w_buffer, scratchpad_buffer, - scratchpad_size); - } catch (oneapi::mkl::lapack::exception const& e) { - std::cerr << "Unexpected exception caught during call to LAPACK API: hegvd" - << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - info_val = static_cast(e.info()); - ret_val = 1; - } catch (sycl::exception const& e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - ret_val = 1; - } - queue.submit([&, info_val](sycl::handler &cgh) { - auto info_acc = info_buf.get_access(cgh); - cgh.single_task>( - [=]() { info_acc[0] = info_val; }); - }); - return ret_val; -#else - try { - oneapi::mkl::lapack::hegvd(queue, itype, jobz, uplo, n, (Ty *)a, lda, (Ty *)b, - ldb, w, (Ty *)scratchpad, scratchpad_size); - } catch (oneapi::mkl::lapack::exception const& e) { - std::cerr << "Unexpected exception caught during call to LAPACK API: hegvd" - << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - int info_val = static_cast(e.info()); - queue.memcpy(info, &info_val, sizeof(int)).wait(); - return 1; - } catch (sycl::exception const& e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - queue.memset(info, 0, sizeof(int)).wait(); - return 1; - } - queue.memset(info, 0, sizeof(int)); - return 0; -#endif -} -/// Computes the Cholesky factorizations of a batch of symmetric (or Hermitian, -/// for complex data) positive-definite matrices. -/// \return Returns 0 if no synchronous exception, otherwise returns 1. -/// \param [in] queue Device queue where calculations will be performed. It must -/// have the in_order property when using the USM mode (DPCT_USM_LEVEL_NONE is -/// not defined). -/// \param [in] uplo Must be uplo::upper or uplo::lower. -/// \param [in] n The order of the matrix A. -/// \param [in,out] a Array of pointers to matrix A. -/// \param [in] lda The leading dimension of matrix A. -/// \param [out] info If lapack synchronous exception is caught, the value -/// returned from info() method of the exception is set to \p info. -/// \param [in] group_size The batch size. -template -inline int potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, int n, - T *a[], int lda, int *info, int group_size) { -#ifdef DPCT_USM_LEVEL_NONE - throw std::runtime_error("this API is unsupported when USM level is none"); -#else - using Ty = typename DataType::T2; - struct matrix_info_t { - oneapi::mkl::uplo uplo_info; - std::int64_t n_info; - std::int64_t lda_info; - std::int64_t group_size_info; - }; - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->uplo_info = uplo; - matrix_info->n_info = n; - matrix_info->lda_info = lda; - matrix_info->group_size_info = group_size; - std::int64_t scratchpad_size = 0; - sycl::event e; - Ty *scratchpad = nullptr; - try { - scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size( - queue, &(matrix_info->uplo_info), &(matrix_info->n_info), - &(matrix_info->lda_info), 1, &(matrix_info->group_size_info)); - scratchpad = sycl::malloc_device(scratchpad_size, queue); - e = oneapi::mkl::lapack::potrf_batch( - queue, &(matrix_info->uplo_info), &(matrix_info->n_info), (Ty **)a, - &(matrix_info->lda_info), 1, &(matrix_info->group_size_info), - scratchpad, scratchpad_size); - } catch (oneapi::mkl::lapack::batch_error const &be) { - std::cerr << "Unexpected exception caught during call to LAPACK API: " - "potrf_batch_scratchpad_size/potrf_batch" - << std::endl - << "reason: " << be.what() << std::endl - << "number: " << be.info() << std::endl; - int i = 0; - auto &ids = be.ids(); - std::vector info_vec(group_size); - for (auto const &e : be.exceptions()) { - try { - std::rethrow_exception(e); - } catch (oneapi::mkl::lapack::exception &e) { - std::cerr << "Exception " << ids[i] << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - info_vec[i] = e.info(); - i++; - } - } - queue.memcpy(info, info_vec.data(), group_size * sizeof(int)).wait(); - std::free(matrix_info); - if (scratchpad) - sycl::free(scratchpad, queue); - return 1; - } catch (sycl::exception const &e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - queue.memset(info, 0, group_size * sizeof(int)).wait(); - std::free(matrix_info); - if (scratchpad) - sycl::free(scratchpad, queue); - return 1; - } - queue.submit([&](sycl::handler &cgh) { - cgh.depends_on(e); - cgh.host_task([=] { - std::free(matrix_info); - sycl::free(scratchpad, queue); - }); - }); - queue.memset(info, 0, group_size * sizeof(int)); - return 0; -#endif -} -/// Solves a batch of systems of linear equations with a Cholesky-factored -/// symmetric (Hermitian) positive-definite coefficient matrices. -/// \return Returns 0 if no synchronous exception, otherwise returns 1. -/// \param [in] queue Device queue where calculations will be performed. It must -/// have the in_order property when using the USM mode (DPCT_USM_LEVEL_NONE is -/// not defined). -/// \param [in] uplo Must be uplo::upper or uplo::lower. -/// \param [in] n The order of the matrix A. -/// \param [in] nrhs The number of right-hand sides. -/// \param [in,out] a Array of pointers to matrix A. -/// \param [in] lda The leading dimension of matrix A. -/// \param [in,out] b Array of pointers to matrix B. -/// \param [in] ldb The leading dimension of matrix B. -/// \param [out] info If lapack synchronous exception is caught, the value -/// returned from info() method of the exception is set to \p info. -/// \param [in] group_size The batch size. -template -inline int potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, int n, - int nrhs, T *a[], int lda, T *b[], int ldb, int *info, - int group_size) { -#ifdef DPCT_USM_LEVEL_NONE - throw std::runtime_error("this API is unsupported when USM level is none"); -#else - using Ty = typename DataType::T2; - struct matrix_info_t { - oneapi::mkl::uplo uplo_info; - std::int64_t n_info; - std::int64_t nrhs_info; - std::int64_t lda_info; - std::int64_t ldb_info; - std::int64_t group_size_info; - }; - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->uplo_info = uplo; - matrix_info->n_info = n; - matrix_info->nrhs_info = nrhs; - matrix_info->lda_info = lda; - matrix_info->ldb_info = ldb; - matrix_info->group_size_info = group_size; - std::int64_t scratchpad_size = 0; - sycl::event e; - Ty *scratchpad = nullptr; - try { - scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size( - queue, &(matrix_info->uplo_info), &(matrix_info->n_info), - &(matrix_info->nrhs_info), &(matrix_info->lda_info), - &(matrix_info->ldb_info), 1, &(matrix_info->group_size_info)); - scratchpad = sycl::malloc_device(scratchpad_size, queue); - e = oneapi::mkl::lapack::potrs_batch( - queue, &(matrix_info->uplo_info), &(matrix_info->n_info), - &(matrix_info->nrhs_info), (Ty **)a, &(matrix_info->lda_info), (Ty **)b, - &(matrix_info->ldb_info), 1, &(matrix_info->group_size_info), - scratchpad, scratchpad_size); - } catch (oneapi::mkl::lapack::batch_error const &be) { - std::cerr << "Unexpected exception caught during call to LAPACK API: " - "potrs_batch_scratchpad_size/potrs_batch" - << std::endl - << "reason: " << be.what() << std::endl - << "number: " << be.info() << std::endl; - int i = 0; - auto &ids = be.ids(); - std::vector info_vec(group_size); - for (auto const &e : be.exceptions()) { - try { - std::rethrow_exception(e); - } catch (oneapi::mkl::lapack::exception &e) { - std::cerr << "Exception " << ids[i] << std::endl - << "reason: " << e.what() << std::endl - << "info: " << e.info() << std::endl; - info_vec[i] = e.info(); - i++; - } - } - queue.memcpy(info, info_vec.data(), group_size * sizeof(int)).wait(); - std::free(matrix_info); - if (scratchpad) - sycl::free(scratchpad, queue); - return 1; - } catch (sycl::exception const &e) { - std::cerr << "Caught synchronous SYCL exception:" << std::endl - << "reason: " << e.what() << std::endl; - queue.memset(info, 0, group_size * sizeof(int)).wait(); - std::free(matrix_info); - if (scratchpad) - sycl::free(scratchpad, queue); - return 1; - } - queue.submit([&](sycl::handler &cgh) { - cgh.depends_on(e); - cgh.host_task([=] { - std::free(matrix_info); - sycl::free(scratchpad, queue); - }); - }); - queue.memset(info, 0, group_size * sizeof(int)); - return 0; -#endif -} - -namespace detail { -template