diff --git a/ggml-sycl/dpct/helper.hpp b/ggml-sycl/dpct/helper.hpp index 97ff5b39d..627599326 100644 --- a/ggml-sycl/dpct/helper.hpp +++ b/ggml-sycl/dpct/helper.hpp @@ -77,7 +77,7 @@ inline std::string get_device_type_name(const sycl::device &Device) { inline std::string get_device_backend_and_type(const sycl::device &device) { std::stringstream device_type; sycl::backend backend = device.get_backend(); - device_type << backend << ":" << get_device_type_name(device); + device_type << backend << ":" << get_device_type_name(device); return device_type.str(); } @@ -220,8 +220,7 @@ namespace dpct // a. and b. i++; minor = std::stoi(&(ver[i])); - } - else { + } else { // c. minor = 0; } @@ -232,7 +231,7 @@ namespace dpct { public: generic_error_type() = default; - generic_error_type(T value) : value{ value } {} + generic_error_type(T value) : value{value} {} operator T() const { return value; } private: @@ -241,7 +240,7 @@ namespace dpct } // namespace detail - /// Pitched 2D/3D memory data. + /// Pitched 2D/3D memory data. class pitched_data { public: @@ -577,7 +576,7 @@ namespace dpct prop.set_max_work_items_per_compute_unit( dev.get_info()); - int max_nd_range_size[] = { 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF }; + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; prop.set_max_nd_range_size(max_nd_range_size); // Estimates max register size per work group, feel free to update the value @@ -590,94 +589,75 @@ namespace dpct } /// dpct device extension - class device_ext : public sycl::device - { + class device_ext : public sycl::device { typedef std::mutex mutex_type; public: device_ext() : sycl::device() {} - ~device_ext() - { + ~device_ext() { std::lock_guard lock(m_mutex); clear_queues(); } - device_ext(const sycl::device &base) : sycl::device(base) - { + device_ext(const sycl::device &base) : sycl::device(base) { std::lock_guard lock(m_mutex); init_queues(); } int is_native_atomic_supported() { return 0; } - int get_major_version() const - { - return dpct::get_major_version(*this); - } + int get_major_version() const { return dpct::get_major_version(*this); } - int get_minor_version() const - { - return dpct::get_minor_version(*this); - } + int get_minor_version() const { return dpct::get_minor_version(*this); } - int get_max_compute_units() const - { + int get_max_compute_units() const { return get_device_info().get_max_compute_units(); } /// Return the maximum clock frequency of this device in KHz. - int get_max_clock_frequency() const - { + int get_max_clock_frequency() const { return get_device_info().get_max_clock_frequency(); } int get_integrated() const { return get_device_info().get_integrated(); } - int get_max_sub_group_size() const - { + int get_max_sub_group_size() const { return get_device_info().get_max_sub_group_size(); } - int get_max_register_size_per_work_group() const - { + int get_max_register_size_per_work_group() const { return get_device_info().get_max_register_size_per_work_group(); } - int get_max_work_group_size() const - { + int get_max_work_group_size() const { return get_device_info().get_max_work_group_size(); } - int get_mem_base_addr_align() const - { + int get_mem_base_addr_align() const { return get_info(); } - size_t get_global_mem_size() const - { + size_t get_global_mem_size() const { return get_device_info().get_global_mem_size(); } - size_t get_max_mem_alloc_size() const - { + size_t get_max_mem_alloc_size() const { return get_device_info().get_max_mem_alloc_size(); } /// Get the number of bytes of free and total memory on the SYCL device. - /// \param [out] free_memory The number of bytes of free memory on the SYCL device. - /// \param [out] total_memory The number of bytes of total memory on the SYCL device. - void get_memory_info(size_t &free_memory, size_t &total_memory) - { + /// \param [out] free_memory The number of bytes of free memory on the + /// SYCL device. \param [out] total_memory The number of bytes of total + /// memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) { total_memory = get_device_info().get_global_mem_size(); - const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " + const char *warning_info = + "get_memory_info: [warning] ext_intel_free_memory is not " "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " "use total memory as free memory"; #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) - if (!has(sycl::aspect::ext_intel_free_memory)) - { + if (!has(sycl::aspect::ext_intel_free_memory)) { std::cerr << warning_info << std::endl; free_memory = total_memory; - } - else - { + } else { free_memory = get_info(); } #else @@ -691,20 +671,17 @@ namespace dpct #endif } - void get_device_info(device_info &out) const - { + void get_device_info(device_info &out) const { dpct::get_device_info(out, *this); } - device_info get_device_info() const - { + device_info get_device_info() const { device_info prop; dpct::get_device_info(prop, *this); return prop; } - void reset() - { + void reset() { std::lock_guard lock(m_mutex); clear_queues(); init_queues(); @@ -714,25 +691,20 @@ namespace dpct sycl::queue &out_of_order_queue() { return _q_out_of_order; } - sycl::queue &default_queue() - { - return in_order_queue(); - } + sycl::queue &default_queue() { return in_order_queue(); } - void queues_wait_and_throw() - { + void queues_wait_and_throw() { std::unique_lock lock(m_mutex); lock.unlock(); - for (auto &q : _queues) - { + for (auto &q : _queues) { q.wait_and_throw(); } - // Guard the destruct of current_queues to make sure the ref count is safe. + // Guard the destruct of current_queues to make sure the ref count is + // safe. lock.lock(); } - sycl::queue create_queue(bool enable_exception_handler = false) - { + sycl::queue create_queue(bool enable_exception_handler = false) { return create_in_order_queue(enable_exception_handler); } @@ -754,52 +726,45 @@ namespace dpct sycl::property::queue::in_order()); } - sycl::queue create_out_of_order_queue(bool enable_exception_handler = false) { + sycl::queue create_out_of_order_queue( + bool enable_exception_handler = false) { std::lock_guard lock(m_mutex); return create_queue_impl(enable_exception_handler); } - void destroy_queue(sycl::queue queue) - { + void destroy_queue(sycl::queue queue) { std::lock_guard lock(m_mutex); _queues.clear(); } - void set_saved_queue(sycl::queue q) - { + void set_saved_queue(sycl::queue q) { std::lock_guard lock(m_mutex); _saved_queue = q; } - sycl::queue get_saved_queue() const - { + sycl::queue get_saved_queue() const { std::lock_guard lock(m_mutex); return _saved_queue; } private: - void clear_queues() - { - _queues.clear(); - } + void clear_queues() { _queues.clear(); } - void init_queues() - { - _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); + void init_queues() { + _q_in_order = + create_queue_impl(true, sycl::property::queue::in_order()); _q_out_of_order = create_queue_impl(true); _saved_queue = default_queue(); } - /// Caller should acquire resource \p m_mutex before calling this function. + /// Caller should acquire resource \p m_mutex before calling this + /// function. template sycl::queue create_queue_impl(bool enable_exception_handler, - Properties... properties) - { + Properties... properties) { sycl::async_handler eh = {}; - if (enable_exception_handler) - { + if (enable_exception_handler) { eh = exception_handler; } - auto q = sycl::queue( - *this, eh, + auto q = sycl::queue(*this, eh, sycl::property_list( #ifdef DPCT_PROFILING_ENABLED sycl::property::queue::enable_profiling(), @@ -818,19 +783,18 @@ namespace dpct if (enable_exception_handler) { eh = exception_handler; } - _queues.push_back(sycl::queue( - device, eh, - sycl::property_list( + _queues.push_back( + sycl::queue(device, eh, + sycl::property_list( #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), + sycl::property::queue::enable_profiling(), #endif - properties...))); + properties...))); return _queues.back(); } - void get_version(int &major, int &minor) const - { + void get_version(int &major, int &minor) const { detail::get_version(*this, major, minor); } sycl::queue _q_in_order, _q_out_of_order; @@ -929,15 +893,15 @@ namespace dpct sycl::backend backend1 = device1.get_backend(); sycl::backend backend2 = device2.get_backend(); // levelzero backends always come first - if (backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true; - if (backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false; + if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true; + if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false; dpct::device_info prop1; dpct::get_device_info(prop1, device1); dpct::device_info prop2; dpct::get_device_info(prop2, device2); return prop1.get_max_compute_units() > prop2.get_max_compute_units(); } - static int convert_backend_index(std::string &backend) { + static int convert_backend_index(std::string & backend) { if (backend == "ext_oneapi_level_zero:gpu") return 0; if (backend == "opencl:gpu") return 1; if (backend == "ext_oneapi_cuda:gpu") return 2; @@ -977,7 +941,7 @@ namespace dpct } std::vector keys; - for (auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { + for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { keys.push_back(it->first); } std::sort(keys.begin(), keys.end(), compare_backend); @@ -1132,7 +1096,7 @@ namespace dpct // Allocation sycl::range<1> r(size); buffer_t buf(r); - allocation A{ buf, next_free, size }; + allocation A{buf, next_free, size}; // Map allocation to device pointer void *result = next_free; m_map.emplace(next_free + size, A); @@ -1242,14 +1206,14 @@ namespace dpct } /** - * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] dev_ptr Pointer to the virtual device memory address. - * @param [in] value The value to be set. - * @param [in] size Number of elements to be set to the value. - * @return An event representing the memset operation. - */ + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ template static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, valueT value, size_t size) @@ -1258,14 +1222,14 @@ namespace dpct } /** - * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] data Pointer to the pitched device memory region. - * @param [in] value The value to be set. - * @param [in] size 3D memory region by number of elements. - * @return An event list representing the memset operations. - */ + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ template static inline std::vector dpct_memset(sycl::queue &q, pitched_data data, valueT value, @@ -1288,16 +1252,16 @@ namespace dpct } /** - * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] ptr Pointer to the virtual device memory. - * @param [in] pitch The pitch size by number of elements, including padding. - * @param [in] val The value to be set. - * @param [in] x The width of memory region by number of elements. - * @param [in] y The height of memory region by number of elements. - * @return An event list representing the memset operations. - */ + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @return An event list representing the memset operations. + */ template static inline std::vector dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, @@ -1319,20 +1283,20 @@ namespace dpct case memcpy_direction::device_to_device: return dir; case memcpy_direction::automatic: - { + { // table[to_attribute][from_attribute] static const memcpy_direction direction_table[static_cast(pointer_access_attribute::end)] [static_cast(pointer_access_attribute::end)] = - { {memcpy_direction::host_to_host, - memcpy_direction::device_to_host, - memcpy_direction::host_to_host}, - {memcpy_direction::host_to_device, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device}, - {memcpy_direction::host_to_host, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device} }; + {{memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}}; return direction_table[static_cast(get_pointer_attribute( q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; } @@ -1411,8 +1375,8 @@ namespace dpct if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { - return { dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events) }; + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; } direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); size_t size_slice = size.get(1) * size.get(0); @@ -1444,7 +1408,7 @@ namespace dpct } break; case host_to_device: - { + { host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, event_list); std::vector host_events; @@ -1475,7 +1439,7 @@ namespace dpct break; } case device_to_host: - { + { host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, event_list); // Copy from host temp buffer to host target with reshaping. @@ -1489,7 +1453,7 @@ namespace dpct break; } case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh) { + event_list.push_back(q.submit([&](sycl::handler &cgh){ cgh.depends_on(dep_events); cgh.parallel_for( size, @@ -1746,7 +1710,7 @@ namespace dpct inline unsigned vectorized_binary(unsigned a, unsigned b, const BinaryOperation binary_op) { - sycl::vec v0{ a }, v1{ b }; + sycl::vec v0{a}, v1{b}; auto v2 = v0.as(); auto v3 = v1.as(); auto v4 = @@ -1793,7 +1757,7 @@ namespace dpct template using dot_product_acc_t = - std::conditional_t &&std::is_unsigned_v, + std::conditional_t && std::is_unsigned_v, uint32_t, int32_t>; template @@ -1821,7 +1785,7 @@ namespace dpct template inline T vectorized_min(T a, T b) { - sycl::vec v0{ a }, v1{ b }; + sycl::vec v0{a}, v1{b}; auto v2 = v0.template as(); auto v3 = v1.template as(); auto v4 = sycl::min(v2, v3); @@ -2099,8 +2063,8 @@ namespace dpct if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { - return { dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events) }; + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; } direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); size_t size_slice = size.get(1) * size.get(0); @@ -2132,7 +2096,7 @@ namespace dpct } break; case host_to_device: - { + { host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, event_list); std::vector host_events; @@ -2163,7 +2127,7 @@ namespace dpct break; } case device_to_host: - { + { host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, event_list); // Copy from host temp buffer to host target with reshaping. @@ -2242,134 +2206,134 @@ namespace dpct case detail::get_type_combination_id( library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { + { detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_impl( + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, - a, lda, b, ldb, &beta_half, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } #endif // __INTEL_MKL__ - default: - throw std::runtime_error("the combination of data type is unsupported"); + default: + throw std::runtime_error("the combination of data type is unsupported"); } } // gemm() - /// Computes a batch of matrix-matrix product with general matrices. - /// \param [in] q The queue where the routine should be executed. - /// \param [in] a_trans Specifies the operation applied to A. - /// \param [in] b_trans Specifies the operation applied to B. - /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. - /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. - /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). - /// \param [in] alpha Scaling factor for the matrix-matrix product. - /// \param [in] a Input matrix A. - /// \param [in] a_type Data type of the matrix A. - /// \param [in] lda Leading dimension of A. - /// \param [in] b Input matrix B. - /// \param [in] b_type Data type of the matrix B. - /// \param [in] ldb Leading dimension of B. - /// \param [in] beta Scaling factor for matrix C. - /// \param [in, out] c Input/Output matrix C. - /// \param [in] c_type Data type of the matrix C. - /// \param [in] ldc Leading dimension of C. - /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. - /// \param [in] scaling_type Data type of the scaling factors. + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void *alpha, const void *a[], @@ -2396,121 +2360,121 @@ namespace dpct case detail::get_type_combination_id( library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { + { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size); break; } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, + a, lda, b, ldb, &beta_float, c, ldc, batch_size); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, + batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); } } @@ -2564,118 +2528,118 @@ namespace dpct case detail::get_type_combination_id( library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): - { + { detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); - break; - } + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, - &beta_half, c, ldc, stride_c, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); } } @@ -2900,7 +2864,7 @@ namespace dpct typename detail::memory_traits::template accessor_t<0>; /// Constructor with initial value. - device_memory(const value_t &val) : base(sycl::range<1>(1), { val }) {} + device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} /// Default constructor device_memory() : base(1) {}