sycl : refactored helper headers into multiple files
This commit is contained in:
parent
44cee5dc89
commit
af514c8d77
13 changed files with 2977 additions and 2931 deletions
|
@ -12,31 +12,32 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <assert.h>
|
||||
#include <atomic>
|
||||
#include <cinttypes>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <float.h>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <regex>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/half_type.hpp>
|
||||
|
||||
#include <oneapi/mkl.hpp>
|
||||
|
||||
#include "ggml-sycl.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-sycl/backend.hpp"
|
||||
#include "ggml-sycl/dpct/atomic.hpp"
|
||||
#include "ggml-sycl/dpct/blas.hpp"
|
||||
#include "ggml-sycl/dpct/helper.hpp"
|
||||
#include "ggml-sycl/dpct/math.hpp"
|
||||
#include "ggml-sycl/dpct/memory.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
|
||||
// Forward declaration only; the definition is not visible in this part of the file.
// NOTE(review): presumably reports whether the SYCL backend has been loaded — confirm at the definition.
bool ggml_sycl_loaded(void);
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
#ifndef GGML_SYCL_COMMON_HPP
|
||||
#define GGML_SYCL_COMMON_HPP
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "dpct/defs.hpp"
|
||||
#include "dpct/device.hpp"
|
||||
#include "dpct/util.hpp"
|
||||
|
||||
#include "ggml-sycl.h"
|
||||
#include "presets.hpp"
|
||||
|
||||
|
@ -34,19 +36,6 @@ static int g_ggml_sycl_debug = 0;
|
|||
fprintf(stderr, __VA_ARGS__); \
|
||||
} while (0)
|
||||
|
||||
// Evaluates `expr` inside an immediately-invoked lambda and maps the outcome to
// a dpct status value:
//   - dpct::success        when `expr` completes without throwing;
//   - dpct::default_error  when `expr` throws something derived from
//                          std::exception (what() plus file/line/function are
//                          written to std::cerr).
// Exceptions that do not derive from std::exception are not caught here.
//
// The enclosing function's name is captured in the lambda introducer: __func__
// evaluated inside the lambda body would name the lambda's operator() instead
// of the function that used the macro.
#define CHECK_TRY_ERROR(expr)                                                  \
    [&, check_try_error_func_ = __func__]() {                                  \
        try {                                                                  \
            expr;                                                              \
            return dpct::success;                                              \
        } catch (std::exception const & e) {                                   \
            std::cerr << e.what() << "\nException caught at file:" << __FILE__ \
                      << ", line:" << __LINE__                                 \
                      << ", func:" << check_try_error_func_ << std::endl;      \
            return dpct::default_error;                                        \
        }                                                                      \
    }()
|
||||
|
||||
// #define DEBUG_SYCL_MALLOC
|
||||
|
||||
// Starts at 0; the code that assigns the real value is not visible here.
// NOTE(review): this is a `static` (internal-linkage) variable in what appears to be a header
// (GGML_SYCL_COMMON_HPP), so every translation unit including it gets its own copy — confirm intended.
static int g_work_group_size = 0;
|
||||
|
|
85
ggml/src/ggml-sycl/dpct/atomic.hpp
Normal file
85
ggml/src/ggml-sycl/dpct/atomic.hpp
Normal file
|
@ -0,0 +1,85 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
namespace dpct {
|
||||
|
||||
template <typename T,
|
||||
sycl::access::address_space addressSpace =
|
||||
sycl::access::address_space::global_space,
|
||||
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
|
||||
sycl::memory_scope memoryScope = sycl::memory_scope::device>
|
||||
inline T atomic_fetch_add(T *addr, T operand) {
|
||||
auto atm =
|
||||
sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
|
||||
return atm.fetch_add(operand);
|
||||
}
|
||||
|
||||
template <sycl::access::address_space addressSpace =
|
||||
sycl::access::address_space::global_space,
|
||||
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
|
||||
sycl::memory_scope memoryScope = sycl::memory_scope::device,
|
||||
typename T1, typename T2>
|
||||
inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
|
||||
auto atm =
|
||||
sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
|
||||
return atm.fetch_add(operand);
|
||||
}
|
||||
|
||||
template <typename T, sycl::access::address_space addressSpace =
|
||||
sycl::access::address_space::global_space>
|
||||
inline T atomic_fetch_add(T *addr, T operand,
|
||||
sycl::memory_order memoryOrder) {
|
||||
switch (memoryOrder) {
|
||||
case sycl::memory_order::relaxed:
|
||||
return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
|
||||
sycl::memory_scope::device>(addr, operand);
|
||||
case sycl::memory_order::acq_rel:
|
||||
return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
|
||||
sycl::memory_scope::device>(addr, operand);
|
||||
case sycl::memory_order::seq_cst:
|
||||
return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
|
||||
sycl::memory_scope::device>(addr, operand);
|
||||
default:
|
||||
assert(false && "Invalid memory_order for atomics. Valid memory_order for "
|
||||
"atomics are: sycl::memory_order::relaxed, "
|
||||
"sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
|
||||
}
|
||||
}
|
||||
|
||||
template <sycl::access::address_space addressSpace =
|
||||
sycl::access::address_space::global_space,
|
||||
typename T1, typename T2>
|
||||
inline T1 atomic_fetch_add(T1 *addr, T2 operand,
|
||||
sycl::memory_order memoryOrder) {
|
||||
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
|
||||
}
|
||||
|
||||
}
|
585
ggml/src/ggml-sycl/dpct/blas.hpp
Normal file
585
ggml/src/ggml-sycl/dpct/blas.hpp
Normal file
|
@ -0,0 +1,585 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <oneapi/mkl.hpp>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
namespace dpct {
|
||||
|
||||
/// Runtime tag for the element type of a library operand. Stored in a single
/// byte; the numeric values are used to build packed type-combination keys.
enum class library_data_t : unsigned char {
  real_float = 0,
  complex_float = 1,
  real_double = 2,
  complex_double = 3,
  real_half = 4,
  complex_half = 5,
  real_bfloat16 = 6,
  complex_bfloat16 = 7,

  real_int4 = 8,
  complex_int4 = 9,
  real_uint4 = 10,
  complex_uint4 = 11,

  real_int8 = 12,
  complex_int8 = 13,
  real_uint8 = 14,
  complex_uint8 = 15,

  real_int16 = 16,
  complex_int16 = 17,
  real_uint16 = 18,
  complex_uint16 = 19,

  real_int32 = 20,
  complex_int32 = 21,
  real_uint32 = 22,
  complex_uint32 = 23,

  real_int64 = 24,
  complex_int64 = 25,
  real_uint64 = 26,
  complex_uint64 = 27,

  real_int8_4 = 28,
  real_int8_32 = 29,
  real_uint8_4 = 30,

  // Not a data type: equals the number of enumerators declared above.
  library_data_t_size = 31
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class Ta, class Tb, class Tc, class Ts>
|
||||
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||
const void *alpha, const void **a, int lda,
|
||||
const void **b, int ldb, const void *beta, void **c,
|
||||
int ldc, int batch_size) {
|
||||
struct matrix_info_t {
|
||||
oneapi::mkl::transpose transpose_info[2];
|
||||
Ts value_info[2];
|
||||
std::int64_t size_info[3];
|
||||
std::int64_t ld_info[3];
|
||||
std::int64_t groupsize_info;
|
||||
};
|
||||
|
||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||
|
||||
matrix_info_t *matrix_info =
|
||||
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
||||
matrix_info->transpose_info[0] = a_trans;
|
||||
matrix_info->transpose_info[1] = b_trans;
|
||||
matrix_info->value_info[0] = alpha_value;
|
||||
matrix_info->value_info[1] = beta_value;
|
||||
matrix_info->size_info[0] = m;
|
||||
matrix_info->size_info[1] = n;
|
||||
matrix_info->size_info[2] = k;
|
||||
matrix_info->ld_info[0] = lda;
|
||||
matrix_info->ld_info[1] = ldb;
|
||||
matrix_info->ld_info[2] = ldc;
|
||||
matrix_info->groupsize_info = batch_size;
|
||||
|
||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
||||
matrix_info->size_info, matrix_info->size_info + 1,
|
||||
matrix_info->size_info + 2, matrix_info->value_info,
|
||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
||||
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||
|
||||
q.submit([&](sycl::handler &cgh) {
|
||||
cgh.depends_on(e);
|
||||
cgh.host_task([=] { std::free(matrix_info); });
|
||||
});
|
||||
}
|
||||
|
||||
template <class Ta, class Tb, class Tc, class Ts>
|
||||
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||
const void *alpha, const void *a, int lda,
|
||||
long long int stride_a, const void *b, int ldb,
|
||||
long long int stride_b, const void *beta, void *c,
|
||||
int ldc, long long int stride_c, int batch_size) {
|
||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||
auto data_a = get_memory<const Ta>(a);
|
||||
auto data_b = get_memory<const Tb>(b);
|
||||
auto data_c = get_memory<Tc>(c);
|
||||
oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a,
|
||||
data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, batch_size);
|
||||
}
|
||||
|
||||
template <typename ArgT>
|
||||
inline constexpr std::uint64_t get_type_combination_id(ArgT Val) {
|
||||
static_assert((unsigned char)library_data_t::library_data_t_size <=
|
||||
std::numeric_limits<unsigned char>::max() &&
|
||||
"library_data_t size exceeds limit.");
|
||||
static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
|
||||
return (std::uint64_t)Val;
|
||||
}
|
||||
|
||||
template <typename FirstT, typename... RestT>
|
||||
inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
|
||||
RestT... RestVal) {
|
||||
static_assert((std::uint8_t)library_data_t::library_data_t_size <=
|
||||
std::numeric_limits<unsigned char>::max() &&
|
||||
"library_data_t size exceeds limit.");
|
||||
static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
|
||||
static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
|
||||
return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
|
||||
}
|
||||
|
||||
template <class Ta, class Tb, class Tc, class Ts>
|
||||
inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||
const void *alpha, const void *a, int lda, const void *b,
|
||||
int ldb, const void *beta, void *c, int ldc) {
|
||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||
auto data_a = get_memory<const Ta>(a);
|
||||
auto data_b = get_memory<const Tb>(b);
|
||||
auto data_c = get_memory<Tc>(c);
|
||||
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k,
|
||||
alpha_value, data_a, lda, data_b, ldb,
|
||||
beta_value, data_c, ldc);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Computes a matrix-matrix product with general matrices, selecting the typed
/// implementation from the runtime data-type tags.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m, n, k Problem sizes forwarded to the typed implementation.
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] scaling_type Data type of the scaling factors.
/// \throws std::runtime_error if the type combination is not handled.
inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
                 oneapi::mkl::transpose b_trans, int m, int n, int k,
                 const void *alpha, const void *a, library_data_t a_type,
                 int lda, const void *b, library_data_t b_type, int ldb,
                 const void *beta, void *c, library_data_t c_type, int ldc,
                 library_data_t scaling_type) {
  using dt = library_data_t;

  // A real scaling type combined with a complex C of the same precision is
  // treated as the corresponding complex scaling type.
  if (c_type == dt::complex_float && scaling_type == dt::real_float) {
    scaling_type = dt::complex_float;
  } else if (c_type == dt::complex_double && scaling_type == dt::real_double) {
    scaling_type = dt::complex_double;
  }

  const std::uint64_t type_combo =
      detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);

  switch (type_combo) {
  case detail::get_type_combination_id(dt::real_float, dt::real_float,
                                       dt::real_float, dt::real_float):
    detail::gemm_impl<float, float, float, float>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_double, dt::real_double,
                                       dt::real_double, dt::real_double):
    detail::gemm_impl<double, double, double, double>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::complex_float, dt::complex_float,
                                       dt::complex_float, dt::complex_float):
    detail::gemm_impl<std::complex<float>, std::complex<float>,
                      std::complex<float>, std::complex<float>>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::complex_double, dt::complex_double,
                                       dt::complex_double, dt::complex_double):
    detail::gemm_impl<std::complex<double>, std::complex<double>,
                      std::complex<double>, std::complex<double>>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_half, dt::real_half,
                                       dt::real_half, dt::real_half):
    detail::gemm_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

#ifdef __INTEL_MKL__
  case detail::get_type_combination_id(dt::real_bfloat16, dt::real_bfloat16,
                                       dt::real_float, dt::real_float):
    detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
                      float>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_half, dt::real_half,
                                       dt::real_float, dt::real_float):
    detail::gemm_impl<sycl::half, sycl::half, float, float>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_half, dt::real_half,
                                       dt::real_half, dt::real_float): {
    // float scaling factors are narrowed to half so the all-half kernel can
    // be used.
    float alpha_f32 =
        dpct::get_value(reinterpret_cast<const float *>(alpha), q);
    float beta_f32 = dpct::get_value(reinterpret_cast<const float *>(beta), q);
    sycl::half alpha_f16(alpha_f32);
    sycl::half beta_f16(beta_f32);
    detail::gemm_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
        q, a_trans, b_trans, m, n, k, &alpha_f16, a, lda, b, ldb, &beta_f16, c,
        ldc);
    break;
  }

  case detail::get_type_combination_id(dt::real_int8, dt::real_int8,
                                       dt::real_float, dt::real_float):
    detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_bfloat16, dt::real_bfloat16,
                                       dt::real_bfloat16, dt::real_float):
    detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
                      oneapi::mkl::bfloat16, float>(
        q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    break;

  case detail::get_type_combination_id(dt::real_int8, dt::real_int8,
                                       dt::real_int32, dt::real_int32): {
    // int32 scaling factors are read and converted to float for the call.
    float alpha_as_float =
        dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
    float beta_as_float =
        dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
    detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
        q, a_trans, b_trans, m, n, k, &alpha_as_float, a, lda, b, ldb,
        &beta_as_float, c, ldc);
    break;
  }
#endif // __INTEL_MKL__

  default:
    throw std::runtime_error("the combination of data type is unsupported");
  }
} // gemm()
|
||||
|
||||
/// Computes a batch of matrix-matrix product with general matrices.
|
||||
/// \param [in] q The queue where the routine should be executed.
|
||||
/// \param [in] a_trans Specifies the operation applied to A.
|
||||
/// \param [in] b_trans Specifies the operation applied to B.
|
||||
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the
|
||||
/// matrix C. \param [in] n Specifies the number of columns of the matrix op(B)
|
||||
/// and of the matrix C. \param [in] k Specifies the number of columns of the
|
||||
/// matrix op(A) and the number of rows of the matrix op(B). \param [in] alpha
|
||||
/// Scaling factor for the matrix-matrix product. \param [in] a Input matrix A.
|
||||
/// \param [in] a_type Data type of the matrix A.
|
||||
/// \param [in] lda Leading dimension of A.
|
||||
/// \param [in] b Input matrix B.
|
||||
/// \param [in] b_type Data type of the matrix B.
|
||||
/// \param [in] ldb Leading dimension of B.
|
||||
/// \param [in] beta Scaling factor for matrix C.
|
||||
/// \param [in, out] c Input/Output matrix C.
|
||||
/// \param [in] c_type Data type of the matrix C.
|
||||
/// \param [in] ldc Leading dimension of C.
|
||||
/// \param [in] batch_size Specifies the number of matrix multiply operations to
|
||||
/// perform. \param [in] scaling_type Data type of the scaling factors.
|
||||
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||
const void *alpha, const void *a[],
|
||||
library_data_t a_type, int lda, const void *b[],
|
||||
library_data_t b_type, int ldb, const void *beta,
|
||||
void *c[], library_data_t c_type, int ldc,
|
||||
int batch_size, library_data_t scaling_type) {
|
||||
if (scaling_type == library_data_t::real_float &&
|
||||
c_type == library_data_t::complex_float) {
|
||||
scaling_type = library_data_t::complex_float;
|
||||
} else if (scaling_type == library_data_t::real_double &&
|
||||
c_type == library_data_t::complex_double) {
|
||||
scaling_type = library_data_t::complex_double;
|
||||
}
|
||||
|
||||
std::uint64_t key =
|
||||
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
||||
switch (key) {
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_float, library_data_t::real_float,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<float, float, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_double, library_data_t::real_double,
|
||||
library_data_t::real_double, library_data_t::real_double): {
|
||||
detail::gemm_batch_impl<double, double, double, double>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::complex_float, library_data_t::complex_float,
|
||||
library_data_t::complex_float, library_data_t::complex_float): {
|
||||
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
||||
std::complex<float>, std::complex<float>>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::complex_double, library_data_t::complex_double,
|
||||
library_data_t::complex_double, library_data_t::complex_double): {
|
||||
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
||||
std::complex<double>, std::complex<double>>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_half): {
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
#ifdef __INTEL_MKL__
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||
library_data_t::real_bfloat16, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
||||
oneapi::mkl::bfloat16, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
||||
float, float>(q, a_trans, b_trans, m, n, k,
|
||||
alpha, a, lda, b, ldb, beta, c,
|
||||
ldc, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_int32, library_data_t::real_int32): {
|
||||
float alpha_float =
|
||||
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
||||
float beta_float =
|
||||
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
|
||||
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb,
|
||||
&beta_float, c, ldc, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_float): {
|
||||
float alpha_value =
|
||||
dpct::get_value(reinterpret_cast<const float *>(alpha), q);
|
||||
float beta_value =
|
||||
dpct::get_value(reinterpret_cast<const float *>(beta), q);
|
||||
sycl::half alpha_half(alpha_value);
|
||||
sycl::half beta_half(beta_value);
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb,
|
||||
&beta_half, c, ldc, batch_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("the combination of data type is unsupported");
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a batch of matrix-matrix product with general matrices.
|
||||
/// \param [in] q The queue where the routine should be executed.
|
||||
/// \param [in] a_trans Specifies the operation applied to A.
|
||||
/// \param [in] b_trans Specifies the operation applied to B.
|
||||
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the
|
||||
/// matrix C. \param [in] n Specifies the number of columns of the matrix op(B)
|
||||
/// and of the matrix C. \param [in] k Specifies the number of columns of the
|
||||
/// matrix op(A) and the number of rows of the matrix op(B). \param [in] alpha
|
||||
/// Scaling factor for the matrix-matrix product. \param [in] a Input matrix A.
|
||||
/// \param [in] a_type Data type of the matrix A.
|
||||
/// \param [in] lda Leading dimension of A.
|
||||
/// \param [in] stride_a Stride between the different A matrices.
|
||||
/// \param [in] b Input matrix B.
|
||||
/// \param [in] b_type Data type of the matrix B.
|
||||
/// \param [in] ldb Leading dimension of B.
|
||||
/// \param [in] stride_b Stride between the different B matrices.
|
||||
/// \param [in] beta Scaling factor for matrix C.
|
||||
/// \param [in, out] c Input/Output matrix C.
|
||||
/// \param [in] c_type Data type of the matrix C.
|
||||
/// \param [in] ldc Leading dimension of C.
|
||||
/// \param [in] stride_c Stride between the different C matrices.
|
||||
/// \param [in] batch_size Specifies the number of matrix multiply operations to
|
||||
/// perform. \param [in] scaling_type Data type of the scaling factors.
|
||||
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
||||
const void *alpha, const void *a, library_data_t a_type,
|
||||
int lda, long long int stride_a, const void *b,
|
||||
library_data_t b_type, int ldb, long long int stride_b,
|
||||
const void *beta, void *c, library_data_t c_type,
|
||||
int ldc, long long int stride_c, int batch_size,
|
||||
library_data_t scaling_type) {
|
||||
if (scaling_type == library_data_t::real_float &&
|
||||
c_type == library_data_t::complex_float) {
|
||||
scaling_type = library_data_t::complex_float;
|
||||
} else if (scaling_type == library_data_t::real_double &&
|
||||
c_type == library_data_t::complex_double) {
|
||||
scaling_type = library_data_t::complex_double;
|
||||
}
|
||||
|
||||
std::uint64_t key =
|
||||
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
||||
switch (key) {
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_float, library_data_t::real_float,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<float, float, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_double, library_data_t::real_double,
|
||||
library_data_t::real_double, library_data_t::real_double): {
|
||||
detail::gemm_batch_impl<double, double, double, double>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::complex_float, library_data_t::complex_float,
|
||||
library_data_t::complex_float, library_data_t::complex_float): {
|
||||
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
||||
std::complex<float>, std::complex<float>>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::complex_double, library_data_t::complex_double,
|
||||
library_data_t::complex_double, library_data_t::complex_double): {
|
||||
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
||||
std::complex<double>, std::complex<double>>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_half): {
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
#ifdef __INTEL_MKL__
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||
library_data_t::real_bfloat16, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
||||
oneapi::mkl::bfloat16, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
||||
float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_int32, library_data_t::real_int32): {
|
||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
||||
std::int32_t>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_float, library_data_t::real_float): {
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb,
|
||||
stride_b, beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_float): {
|
||||
float alpha_value =
|
||||
dpct::get_value(reinterpret_cast<const float *>(alpha), q);
|
||||
float beta_value =
|
||||
dpct::get_value(reinterpret_cast<const float *>(beta), q);
|
||||
sycl::half alpha_half(alpha_value);
|
||||
sycl::half beta_half(beta_value);
|
||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb,
|
||||
stride_b, &beta_half, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("the combination of data type is unsupported");
|
||||
}
|
||||
}
|
||||
} // namespace dpct
|
65
ggml/src/ggml-sycl/dpct/common.hpp
Normal file
65
ggml/src/ggml-sycl/dpct/common.hpp
Normal file
|
@ -0,0 +1,65 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "memory.hpp"
|
||||
|
||||
namespace dpct {
|
||||
namespace detail {
|
||||
|
||||
/// Maps a device-side element type to the type used to hold its value on the
/// host. The primary template is the identity mapping: T2 is T itself.
template <typename T>
struct DataType {
    typedef T T2;
};
|
||||
|
||||
/// Specialization for two-component SYCL vectors: the host-side type T2 is
/// std::complex<T> rather than sycl::vec<T, 2>.
template <typename T>
struct DataType<sycl::vec<T, 2>> {
    typedef std::complex<T> T2;
};
|
||||
|
||||
template <typename T>
|
||||
inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q) {
|
||||
using Ty = typename DataType<T>::T2;
|
||||
Ty s_h;
|
||||
if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) {
|
||||
detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T),
|
||||
device_to_host)
|
||||
.wait();
|
||||
} else {
|
||||
s_h = *reinterpret_cast<const Ty *>(s);
|
||||
}
|
||||
return s_h;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename T> inline auto get_value(const T *s, sycl::queue &q) {
|
||||
return detail::get_value(s, q);
|
||||
}
|
||||
|
||||
} // namespace dpct
|
61
ggml/src/ggml-sycl/dpct/defs.hpp
Normal file
61
ggml/src/ggml-sycl/dpct/defs.hpp
Normal file
|
@ -0,0 +1,61 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#define DPCT_COMPATIBILITY_TEMP (900)

// Compiler-specific spellings for alignment and inlining control. MSVC uses
// __declspec / __forceinline; every other compiler gets the GNU-style
// __attribute__ forms.
#if defined(_MSC_VER)
#    define __dpct_align__(n) __declspec(align(n))
#    define __dpct_inline__ __forceinline
#    define __dpct_noinline__ __declspec(noinline)
#else
#    define __dpct_align__(n) __attribute__((aligned(n)))
#    define __dpct_inline__ __inline__ __attribute__((always_inline))
#    define __dpct_noinline__ __attribute__((noinline))
#endif
|
||||
|
||||
namespace dpct {

/// Status values produced by CHECK_TRY_ERROR: `success` when the wrapped
/// expression completes, `default_error` when it throws a std::exception.
enum error_code {
    success = 0,
    default_error = 999
};

} // namespace dpct
|
||||
|
||||
/* Evaluates `expr` inside an immediately-invoked lambda and converts the
 * outcome into a dpct::error_code: dpct::success when `expr` completes,
 * dpct::default_error when it throws something derived from std::exception
 * (the message and the catch location are written to std::cerr first).
 * Exceptions that do not derive from std::exception are not caught here.
 * The expansion uses std::cerr, so <iostream> must be included at the point
 * of use. */
#define CHECK_TRY_ERROR(expr)                                                  \
    [&]() -> dpct::error_code {                                                \
        try {                                                                  \
            expr;                                                              \
        } catch (std::exception const &caught) {                               \
            std::cerr << caught.what()                                         \
                      << "\nException caught at file:" << __FILE__             \
                      << ", line:" << __LINE__ << ", func:" << __func__        \
                      << std::endl;                                            \
            return dpct::default_error;                                        \
        }                                                                      \
        return dpct::success;                                                  \
    }()
|
856
ggml/src/ggml-sycl/dpct/device.hpp
Normal file
856
ggml/src/ggml-sycl/dpct/device.hpp
Normal file
|
@ -0,0 +1,856 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#if defined(_WIN64)
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include "helper.hpp"
|
||||
|
||||
namespace dpct {
|
||||
|
||||
// Raw-pointer aliases used across the dpct helpers.
using queue_ptr = sycl::queue *;
using event_ptr = sycl::event *;
using device_ptr = char *;
|
||||
|
||||
/// SYCL default exception handler
|
||||
inline auto exception_handler = [](sycl::exception_list exceptions) {
|
||||
for (std::exception_ptr const &e : exceptions) {
|
||||
try {
|
||||
std::rethrow_exception(e);
|
||||
} catch (sycl::exception const &e) {
|
||||
std::cerr << "Caught asynchronous SYCL exception:" << std::endl
|
||||
<< e.what() << std::endl
|
||||
<< "Exception caught at file:" << __FILE__
|
||||
<< ", line:" << __LINE__ << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Extracts a major/minor version pair from the device's version string.
///
/// The version string has one of the following formats:
///   a. OpenCL<space><major.minor><space><vendor-specific-information>
///   b. <major.minor>
///   c. <AmdGcnArchName> e.g gfx1030
///
/// \param dev   Device whose sycl::info::device::version is parsed.
/// \param major [out] Integer starting at the first decimal digit.
/// \param minor [out] Integer following the first '.' found after that digit,
///              or 0 when there is no '.' (format c).
/// \throws std::invalid_argument / std::out_of_range from std::stoi when no
///         parsable number is found at the expected position.
static void get_version(const sycl::device &dev, int &major, int &minor) {
    const std::string ver = dev.get_info<sycl::info::device::version>();

    // Skip ahead to the first decimal digit. The cast matters: passing a
    // negative plain char to isdigit() is undefined behaviour.
    std::string::size_type i = 0;
    while (i < ver.size() && !isdigit(static_cast<unsigned char>(ver[i]))) {
        ++i;
    }
    major = std::stoi(ver.c_str() + i);

    // Look for the separator between major and minor.
    while (i < ver.size() && ver[i] != '.') {
        ++i;
    }

    if (i < ver.size()) {
        // Formats a. and b.: the minor number follows the dot.
        ++i;
        minor = std::stoi(ver.c_str() + i);
    } else {
        // Format c.: no dot, so there is no minor component.
        minor = 0;
    }
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class device_info {
|
||||
public:
|
||||
// get interface
|
||||
const char *get_name() const { return _name; }
|
||||
char *get_name() { return _name; }
|
||||
template <
|
||||
typename WorkItemSizesTy = sycl::range<3>,
|
||||
std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
|
||||
std::is_same_v<WorkItemSizesTy, int *>,
|
||||
int> = 0>
|
||||
auto get_max_work_item_sizes() const {
|
||||
if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
|
||||
return sycl::range<3>(_max_work_item_sizes_i[0],
|
||||
_max_work_item_sizes_i[1],
|
||||
_max_work_item_sizes_i[2]);
|
||||
else {
|
||||
return _max_work_item_sizes_i;
|
||||
}
|
||||
}
|
||||
template <
|
||||
typename WorkItemSizesTy = sycl::range<3>,
|
||||
std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
|
||||
std::is_same_v<WorkItemSizesTy, int *>,
|
||||
int> = 0>
|
||||
auto get_max_work_item_sizes() {
|
||||
if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
|
||||
return sycl::range<3>(_max_work_item_sizes_i[0],
|
||||
_max_work_item_sizes_i[1],
|
||||
_max_work_item_sizes_i[2]);
|
||||
else {
|
||||
return _max_work_item_sizes_i;
|
||||
}
|
||||
}
|
||||
bool get_host_unified_memory() const { return _host_unified_memory; }
|
||||
int get_major_version() const { return _major; }
|
||||
int get_minor_version() const { return _minor; }
|
||||
int get_integrated() const { return _integrated; }
|
||||
int get_max_clock_frequency() const { return _frequency; }
|
||||
int get_max_compute_units() const { return _max_compute_units; }
|
||||
int get_max_work_group_size() const { return _max_work_group_size; }
|
||||
int get_max_sub_group_size() const { return _max_sub_group_size; }
|
||||
int get_max_work_items_per_compute_unit() const {
|
||||
return _max_work_items_per_compute_unit;
|
||||
}
|
||||
int get_max_register_size_per_work_group() const {
|
||||
return _max_register_size_per_work_group;
|
||||
}
|
||||
template <typename NDRangeSizeTy = size_t *,
|
||||
std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
|
||||
std::is_same_v<NDRangeSizeTy, int *>,
|
||||
int> = 0>
|
||||
auto get_max_nd_range_size() const {
|
||||
if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
|
||||
return _max_nd_range_size;
|
||||
else
|
||||
return _max_nd_range_size_i;
|
||||
}
|
||||
template <typename NDRangeSizeTy = size_t *,
|
||||
std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
|
||||
std::is_same_v<NDRangeSizeTy, int *>,
|
||||
int> = 0>
|
||||
auto get_max_nd_range_size() {
|
||||
if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
|
||||
return _max_nd_range_size;
|
||||
else
|
||||
return _max_nd_range_size_i;
|
||||
}
|
||||
size_t get_global_mem_size() const { return _global_mem_size; }
|
||||
size_t get_local_mem_size() const { return _local_mem_size; }
|
||||
size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
|
||||
/// Returns the maximum clock rate of device's global memory in kHz. If
|
||||
/// compiler does not support this API then returns default value 3200000
|
||||
/// kHz.
|
||||
unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
|
||||
/// Returns the maximum bus width between device and memory in bits. If
|
||||
/// compiler does not support this API then returns default value 64 bits.
|
||||
unsigned int get_memory_bus_width() const { return _memory_bus_width; }
|
||||
uint32_t get_device_id() const { return _device_id; }
|
||||
std::array<unsigned char, 16> get_uuid() const { return _uuid; }
|
||||
/// Returns global memory cache size in bytes.
|
||||
unsigned int get_global_mem_cache_size() const {
|
||||
return _global_mem_cache_size;
|
||||
}
|
||||
|
||||
// set interface
|
||||
void set_name(const char *name) {
|
||||
size_t length = strlen(name);
|
||||
if (length < 256) {
|
||||
std::memcpy(_name, name, length + 1);
|
||||
} else {
|
||||
std::memcpy(_name, name, 255);
|
||||
_name[255] = '\0';
|
||||
}
|
||||
}
|
||||
void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) {
|
||||
for (int i = 0; i < 3; ++i)
|
||||
_max_work_item_sizes_i[i] = max_work_item_sizes[i];
|
||||
}
|
||||
[[deprecated]] void
|
||||
set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
_max_work_item_sizes_i[i] = max_work_item_sizes[i];
|
||||
}
|
||||
}
|
||||
void set_host_unified_memory(bool host_unified_memory) {
|
||||
_host_unified_memory = host_unified_memory;
|
||||
}
|
||||
void set_major_version(int major) { _major = major; }
|
||||
void set_minor_version(int minor) { _minor = minor; }
|
||||
void set_integrated(int integrated) { _integrated = integrated; }
|
||||
void set_max_clock_frequency(int frequency) { _frequency = frequency; }
|
||||
void set_max_compute_units(int max_compute_units) {
|
||||
_max_compute_units = max_compute_units;
|
||||
}
|
||||
void set_global_mem_size(size_t global_mem_size) {
|
||||
_global_mem_size = global_mem_size;
|
||||
}
|
||||
void set_local_mem_size(size_t local_mem_size) {
|
||||
_local_mem_size = local_mem_size;
|
||||
}
|
||||
void set_max_mem_alloc_size(size_t max_mem_alloc_size) {
|
||||
_max_mem_alloc_size = max_mem_alloc_size;
|
||||
}
|
||||
void set_max_work_group_size(int max_work_group_size) {
|
||||
_max_work_group_size = max_work_group_size;
|
||||
}
|
||||
void set_max_sub_group_size(int max_sub_group_size) {
|
||||
_max_sub_group_size = max_sub_group_size;
|
||||
}
|
||||
void
|
||||
set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) {
|
||||
_max_work_items_per_compute_unit = max_work_items_per_compute_unit;
|
||||
}
|
||||
void set_max_nd_range_size(int max_nd_range_size[]) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
_max_nd_range_size[i] = max_nd_range_size[i];
|
||||
_max_nd_range_size_i[i] = max_nd_range_size[i];
|
||||
}
|
||||
}
|
||||
void set_memory_clock_rate(unsigned int memory_clock_rate) {
|
||||
_memory_clock_rate = memory_clock_rate;
|
||||
}
|
||||
void set_memory_bus_width(unsigned int memory_bus_width) {
|
||||
_memory_bus_width = memory_bus_width;
|
||||
}
|
||||
void
|
||||
set_max_register_size_per_work_group(int max_register_size_per_work_group) {
|
||||
_max_register_size_per_work_group = max_register_size_per_work_group;
|
||||
}
|
||||
void set_device_id(uint32_t device_id) { _device_id = device_id; }
|
||||
void set_uuid(std::array<unsigned char, 16> uuid) {
|
||||
_uuid = std::move(uuid);
|
||||
}
|
||||
void set_global_mem_cache_size(unsigned int global_mem_cache_size) {
|
||||
_global_mem_cache_size = global_mem_cache_size;
|
||||
}
|
||||
|
||||
private:
|
||||
char _name[256];
|
||||
int _max_work_item_sizes_i[3];
|
||||
bool _host_unified_memory = false;
|
||||
int _major;
|
||||
int _minor;
|
||||
int _integrated = 0;
|
||||
int _frequency;
|
||||
// Set estimated value 3200000 kHz as default value.
|
||||
unsigned int _memory_clock_rate = 3200000;
|
||||
// Set estimated value 64 bits as default value.
|
||||
unsigned int _memory_bus_width = 64;
|
||||
unsigned int _global_mem_cache_size;
|
||||
int _max_compute_units;
|
||||
int _max_work_group_size;
|
||||
int _max_sub_group_size;
|
||||
int _max_work_items_per_compute_unit;
|
||||
int _max_register_size_per_work_group;
|
||||
size_t _global_mem_size;
|
||||
size_t _local_mem_size;
|
||||
size_t _max_mem_alloc_size;
|
||||
size_t _max_nd_range_size[3];
|
||||
int _max_nd_range_size_i[3];
|
||||
uint32_t _device_id;
|
||||
std::array<unsigned char, 16> _uuid;
|
||||
};
|
||||
|
||||
/// Returns the major component that detail::get_version() extracts for
/// \p dev; the minor component is computed and discarded.
static int get_major_version(const sycl::device &dev) {
    int major_part;
    int unused_minor;
    detail::get_version(dev, major_part, unused_minor);
    return major_part;
}
|
||||
|
||||
/// Returns the minor component that detail::get_version() extracts for
/// \p dev; the major component is computed and discarded.
static int get_minor_version(const sycl::device &dev) {
    int unused_major;
    int minor_part;
    detail::get_version(dev, unused_major, minor_part);
    return minor_part;
}
|
||||
|
||||
/// Fills \p out with a snapshot of the properties of \p dev.
///
/// The snapshot is built in a local device_info and assigned to \p out only
/// as the last step. Clock values returned by the device are multiplied by
/// 1000 before being stored. The Intel-specific properties (memory clock
/// rate, bus width, device id, uuid) are queried only when
/// SYCL_EXT_INTEL_DEVICE_INFO >= 6 and the device reports the matching
/// aspect; otherwise device_info's own defaults are left in place.
///
/// \param out [out] Receives the collected properties.
/// \param dev Device to query.
static void get_device_info(device_info &out, const sycl::device &dev) {
    device_info info;

    const std::string device_name = dev.get_info<sycl::info::device::name>();
    info.set_name(device_name.c_str());

    int ver_major, ver_minor;
    detail::get_version(dev, ver_major, ver_minor);
    info.set_major_version(ver_major);
    info.set_minor_version(ver_minor);

#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
    // oneAPI DPC++ compiler older than 2022/09/02, where
    // max_work_item_sizes is an enum class element
    info.set_max_work_item_sizes(
        dev.get_info<sycl::info::device::max_work_item_sizes>());
#else
    // SYCL 2020-conformant code, max_work_item_sizes is a struct templated
    // by an int
    info.set_max_work_item_sizes(
        dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
#endif

    info.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));
    info.set_max_clock_frequency(
        dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);
    info.set_max_compute_units(
        dev.get_info<sycl::info::device::max_compute_units>());
    info.set_max_work_group_size(
        dev.get_info<sycl::info::device::max_work_group_size>());
    info.set_global_mem_size(
        dev.get_info<sycl::info::device::global_mem_size>());
    info.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
    info.set_max_mem_alloc_size(
        dev.get_info<sycl::info::device::max_mem_alloc_size>());

#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
    if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) {
        const unsigned int reported_rate =
            dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
        // A reported rate of 0 keeps the default estimate.
        if (reported_rate != 0) {
            info.set_memory_clock_rate(1000 * reported_rate);
        }
    }
    if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) {
        info.set_memory_bus_width(
            dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
    }
    if (dev.has(sycl::aspect::ext_intel_device_id)) {
        info.set_device_id(
            dev.get_info<sycl::ext::intel::info::device::device_id>());
    }
    if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) {
        info.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
    }
#elif defined(_MSC_VER) && !defined(__clang__)
#pragma message("get_device_info: querying memory_clock_rate and \
memory_bus_width are not supported by the compiler used. \
Use 3200000 kHz as memory_clock_rate default value. \
Use 64 bits as memory_bus_width default value.")
#else
#warning "get_device_info: querying memory_clock_rate and \
memory_bus_width are not supported by the compiler used. \
Use 3200000 kHz as memory_clock_rate default value. \
Use 64 bits as memory_bus_width default value."
#endif

    // Largest supported sub-group size, with 1 as the floor.
    size_t widest_sub_group = 1;
    const std::vector<size_t> supported_sub_groups =
        dev.get_info<sycl::info::device::sub_group_sizes>();
    for (const size_t candidate : supported_sub_groups) {
        widest_sub_group = std::max(widest_sub_group, candidate);
    }
    info.set_max_sub_group_size(widest_sub_group);

    // The per-compute-unit work-item limit is filled from max_work_group_size.
    info.set_max_work_items_per_compute_unit(
        dev.get_info<sycl::info::device::max_work_group_size>());

    int nd_range_limit[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
    info.set_max_nd_range_size(nd_range_limit);

    // Estimates max register size per work group, feel free to update the value
    // according to device properties.
    info.set_max_register_size_per_work_group(65536);

    info.set_global_mem_cache_size(
        dev.get_info<sycl::info::device::global_mem_cache_size>());

    out = info;
}
|
||||
|
||||
/// dpct device extension
|
||||
class device_ext : public sycl::device {
|
||||
typedef std::mutex mutex_type;
|
||||
|
||||
public:
|
||||
device_ext() : sycl::device() {}
|
||||
~device_ext() {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
clear_queues();
|
||||
}
|
||||
device_ext(const sycl::device &base) : sycl::device(base) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
init_queues();
|
||||
}
|
||||
|
||||
int is_native_atomic_supported() { return 0; }
|
||||
int get_major_version() const { return dpct::get_major_version(*this); }
|
||||
|
||||
int get_minor_version() const { return dpct::get_minor_version(*this); }
|
||||
|
||||
int get_max_compute_units() const {
|
||||
return get_device_info().get_max_compute_units();
|
||||
}
|
||||
|
||||
/// Return the maximum clock frequency of this device in KHz.
|
||||
int get_max_clock_frequency() const {
|
||||
return get_device_info().get_max_clock_frequency();
|
||||
}
|
||||
|
||||
int get_integrated() const { return get_device_info().get_integrated(); }
|
||||
|
||||
int get_max_sub_group_size() const {
|
||||
return get_device_info().get_max_sub_group_size();
|
||||
}
|
||||
|
||||
int get_max_register_size_per_work_group() const {
|
||||
return get_device_info().get_max_register_size_per_work_group();
|
||||
}
|
||||
|
||||
int get_max_work_group_size() const {
|
||||
return get_device_info().get_max_work_group_size();
|
||||
}
|
||||
|
||||
int get_mem_base_addr_align() const {
|
||||
return get_info<sycl::info::device::mem_base_addr_align>();
|
||||
}
|
||||
|
||||
size_t get_global_mem_size() const {
|
||||
return get_device_info().get_global_mem_size();
|
||||
}
|
||||
|
||||
size_t get_max_mem_alloc_size() const {
|
||||
return get_device_info().get_max_mem_alloc_size();
|
||||
}
|
||||
|
||||
/// Get the number of bytes of free and total memory on the SYCL device.
|
||||
/// \param [out] free_memory The number of bytes of free memory on the
|
||||
/// SYCL device. \param [out] total_memory The number of bytes of total
|
||||
/// memory on the SYCL device.
|
||||
void get_memory_info(size_t &free_memory, size_t &total_memory) {
|
||||
total_memory = get_device_info().get_global_mem_size();
|
||||
const char *warning_info =
|
||||
"get_memory_info: [warning] ext_intel_free_memory is not "
|
||||
"supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
|
||||
"use total memory as free memory";
|
||||
#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
|
||||
if (!has(sycl::aspect::ext_intel_free_memory)) {
|
||||
std::cerr << warning_info << std::endl;
|
||||
free_memory = total_memory;
|
||||
} else {
|
||||
free_memory =
|
||||
get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
}
|
||||
#else
|
||||
std::cerr << warning_info << std::endl;
|
||||
free_memory = total_memory;
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma message("Querying the number of bytes of free memory is not supported")
|
||||
#else
|
||||
#warning "Querying the number of bytes of free memory is not supported"
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
void get_device_info(device_info &out) const {
|
||||
dpct::get_device_info(out, *this);
|
||||
}
|
||||
|
||||
device_info get_device_info() const {
|
||||
device_info prop;
|
||||
dpct::get_device_info(prop, *this);
|
||||
return prop;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
clear_queues();
|
||||
init_queues();
|
||||
}
|
||||
|
||||
sycl::queue &in_order_queue() { return _q_in_order; }
|
||||
|
||||
sycl::queue &out_of_order_queue() { return _q_out_of_order; }
|
||||
|
||||
sycl::queue &default_queue() { return in_order_queue(); }
|
||||
|
||||
void queues_wait_and_throw() {
|
||||
std::unique_lock<mutex_type> lock(m_mutex);
|
||||
lock.unlock();
|
||||
for (auto &q : _queues) {
|
||||
q.wait_and_throw();
|
||||
}
|
||||
// Guard the destruct of current_queues to make sure the ref count is
|
||||
// safe.
|
||||
lock.lock();
|
||||
}
|
||||
|
||||
sycl::queue create_queue(bool enable_exception_handler = false) {
|
||||
return create_in_order_queue(enable_exception_handler);
|
||||
}
|
||||
|
||||
sycl::queue create_queue(sycl::device device,
|
||||
bool enable_exception_handler = false) {
|
||||
return create_in_order_queue(device, enable_exception_handler);
|
||||
}
|
||||
|
||||
sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
return create_queue_impl(enable_exception_handler,
|
||||
sycl::property::queue::in_order());
|
||||
}
|
||||
|
||||
sycl::queue create_in_order_queue(sycl::device device,
|
||||
bool enable_exception_handler = false) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
return create_queue_impl(device, enable_exception_handler,
|
||||
sycl::property::queue::in_order());
|
||||
}
|
||||
|
||||
sycl::queue
|
||||
create_out_of_order_queue(bool enable_exception_handler = false) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
return create_queue_impl(enable_exception_handler);
|
||||
}
|
||||
|
||||
void destroy_queue(sycl::queue queue) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
_queues.clear();
|
||||
}
|
||||
void set_saved_queue(sycl::queue q) {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
_saved_queue = q;
|
||||
}
|
||||
sycl::queue get_saved_queue() const {
|
||||
std::lock_guard<mutex_type> lock(m_mutex);
|
||||
return _saved_queue;
|
||||
}
|
||||
|
||||
private:
|
||||
void clear_queues() { _queues.clear(); }
|
||||
|
||||
void init_queues() {
|
||||
_q_in_order =
|
||||
create_queue_impl(true, sycl::property::queue::in_order());
|
||||
_q_out_of_order = create_queue_impl(true);
|
||||
_saved_queue = default_queue();
|
||||
}
|
||||
|
||||
/// Caller should acquire resource \p m_mutex before calling this
|
||||
/// function.
|
||||
template <class... Properties>
|
||||
sycl::queue create_queue_impl(bool enable_exception_handler,
|
||||
Properties... properties) {
|
||||
sycl::async_handler eh = {};
|
||||
if (enable_exception_handler) {
|
||||
eh = exception_handler;
|
||||
}
|
||||
auto q = sycl::queue(*this, eh,
|
||||
sycl::property_list(
|
||||
#ifdef DPCT_PROFILING_ENABLED
|
||||
sycl::property::queue::enable_profiling(),
|
||||
#endif
|
||||
properties...));
|
||||
_queues.push_back(q);
|
||||
|
||||
return _queues.back();
|
||||
}
|
||||
|
||||
template <class... Properties>
|
||||
sycl::queue create_queue_impl(sycl::device device,
|
||||
bool enable_exception_handler,
|
||||
Properties... properties) {
|
||||
sycl::async_handler eh = {};
|
||||
if (enable_exception_handler) {
|
||||
eh = exception_handler;
|
||||
}
|
||||
_queues.push_back(
|
||||
sycl::queue(device, eh,
|
||||
sycl::property_list(
|
||||
#ifdef DPCT_PROFILING_ENABLED
|
||||
sycl::property::queue::enable_profiling(),
|
||||
#endif
|
||||
properties...)));
|
||||
|
||||
return _queues.back();
|
||||
}
|
||||
|
||||
void get_version(int &major, int &minor) const {
|
||||
detail::get_version(*this, major, minor);
|
||||
}
|
||||
sycl::queue _q_in_order, _q_out_of_order;
|
||||
sycl::queue _saved_queue;
|
||||
std::vector<sycl::queue> _queues;
|
||||
mutable mutex_type m_mutex;
|
||||
};
|
||||
|
||||
/// Returns an identifier for the calling thread: the result of the
/// SYS_gettid system call on Linux, GetCurrentThreadId() on 64-bit Windows.
/// Any other platform is rejected at compile time.
static inline unsigned int get_tid() {
#if !defined(__linux__) && !defined(_WIN64)
#error "Only support Windows and Linux."
#endif

#if defined(__linux__)
    const auto kernel_thread_id = syscall(SYS_gettid);
    return static_cast<unsigned int>(kernel_thread_id);
#elif defined(_WIN64)
    const auto win_thread_id = GetCurrentThreadId();
    return static_cast<unsigned int>(win_thread_id);
#endif
}
|
||||
|
||||
/// device manager
|
||||
class dev_mgr {
|
||||
public:
|
||||
/// Returns the device selected for the calling thread.
///
/// The id comes from current_device_id() and is validated with check_id()
/// before it is used to index the device list.
device_ext &current_device() {
    unsigned int dev_id = current_device_id();
    check_id(dev_id);
    return *_devs[dev_id];
}
|
||||
/// Returns the device recorded as the CPU device.
/// \throws std::runtime_error when no CPU device was recorded
///         (_cpu_device is still -1).
device_ext &cpu_device() const {
    std::lock_guard<std::recursive_mutex> guard(m_mutex);
    if (_cpu_device != -1) {
        return *_devs[_cpu_device];
    }
    throw std::runtime_error("no valid cpu device");
}
|
||||
/// Returns the device stored at index \p id after validating the index with
/// check_id(). The lookup runs under m_mutex.
device_ext &get_device(unsigned int id) const {
    std::lock_guard<std::recursive_mutex> guard(m_mutex);
    check_id(id);
    device_ext &requested = *_devs[id];
    return requested;
}
|
||||
unsigned int current_device_id() const {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
auto it = _thread2dev_map.find(get_tid());
|
||||
if (it != _thread2dev_map.end())
|
||||
return it->second;
|
||||
return DEFAULT_DEVICE_ID;
|
||||
}
|
||||
|
||||
/// Select device with a device ID.
|
||||
/// \param [in] id The id of the device which can
|
||||
/// be obtained through get_device_id(const sycl::device).
|
||||
void select_device(unsigned int id) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
check_id(id);
|
||||
_thread2dev_map[get_tid()] = id;
|
||||
}
|
||||
/// Returns the number of devices held by the manager.
unsigned int device_count() {
    return static_cast<unsigned int>(_devs.size());
}
|
||||
|
||||
/// Returns the index of the first managed device that compares equal to
/// \p dev. When no device matches, the returned value equals the number of
/// managed devices (i.e. one past the last valid index).
unsigned int get_device_id(const sycl::device &dev) {
    unsigned int index = 0;
    while (index < _devs.size() && !(*_devs[index] == dev)) {
        ++index;
    }
    return index;
}
|
||||
|
||||
template <class DeviceSelector>
|
||||
std::enable_if_t<
|
||||
std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
|
||||
select_device(const DeviceSelector &selector = sycl::gpu_selector_v) {
|
||||
sycl::device selected_device = sycl::device(selector);
|
||||
unsigned int selected_device_id = get_device_id(selected_device);
|
||||
select_device(selected_device_id);
|
||||
}
|
||||
|
||||
/// Returns the instance of device manager singleton.
|
||||
static dev_mgr &instance() {
|
||||
static dev_mgr d_m;
|
||||
return d_m;
|
||||
}
|
||||
dev_mgr(const dev_mgr &) = delete;
|
||||
dev_mgr &operator=(const dev_mgr &) = delete;
|
||||
dev_mgr(dev_mgr &&) = delete;
|
||||
dev_mgr &operator=(dev_mgr &&) = delete;
|
||||
|
||||
private:
|
||||
// Locked by cpu_device(), get_device(), current_device_id() and
// select_device(unsigned int). Declared mutable so the const accessors can
// lock it.
mutable std::recursive_mutex m_mutex;
|
||||
static bool compare_dev(sycl::device &device1, sycl::device &device2) {
|
||||
sycl::backend backend1 = device1.get_backend();
|
||||
sycl::backend backend2 = device2.get_backend();
|
||||
// levelzero backends always come first
|
||||
if (backend1 == sycl::backend::ext_oneapi_level_zero &&
|
||||
backend2 != sycl::backend::ext_oneapi_level_zero)
|
||||
return true;
|
||||
if (backend1 != sycl::backend::ext_oneapi_level_zero &&
|
||||
backend2 == sycl::backend::ext_oneapi_level_zero)
|
||||
return false;
|
||||
dpct::device_info prop1;
|
||||
dpct::get_device_info(prop1, device1);
|
||||
dpct::device_info prop2;
|
||||
dpct::get_device_info(prop2, device2);
|
||||
return prop1.get_max_compute_units() > prop2.get_max_compute_units();
|
||||
}
|
||||
/// Maps a "<backend>:<device type>" key to its sort rank (lower sorts first).
///
/// An unrecognized key is reported on stdout and then triggers
/// GGML_ASSERT(false).
///
/// \param backend Key such as "ext_oneapi_level_zero:gpu" or "opencl:cpu".
/// \return Rank in the range 0..5 for a recognized key.
static int convert_backend_index(const std::string &backend) {
    if (backend == "ext_oneapi_level_zero:gpu")
        return 0;
    if (backend == "opencl:gpu")
        return 1;
    if (backend == "ext_oneapi_cuda:gpu")
        return 2;
    if (backend == "ext_oneapi_hip:gpu")
        return 3;
    if (backend == "opencl:cpu")
        return 4;
    if (backend == "opencl:acc")
        return 5;
    printf("convert_backend_index: can't handle backend=%s\n",
           backend.c_str());
    GGML_ASSERT(false);
    // Not expected to be reached; keeps the function from flowing off its
    // end should the assertion macro ever return.
    return -1;
}

/// Ordering predicate for backend keys: true when \p backend1 has a lower
/// rank than \p backend2. Takes const references so it is a valid comparator
/// for std::sort.
static bool compare_backend(const std::string &backend1,
                            const std::string &backend2) {
    return convert_backend_index(backend1) <
           convert_backend_index(backend2);
}
|
||||
dev_mgr() {
|
||||
sycl::device default_device = sycl::device(sycl::default_selector_v);
|
||||
_devs.push_back(std::make_shared<device_ext>(default_device));
|
||||
|
||||
std::vector<sycl::device> sycl_all_devs;
|
||||
// Collect other devices except for the default device.
|
||||
if (default_device.is_cpu())
|
||||
_cpu_device = 0;
|
||||
|
||||
auto Platforms = sycl::platform::get_platforms();
|
||||
// Keep track of the number of devices per backend
|
||||
std::map<sycl::backend, size_t> DeviceNums;
|
||||
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
||||
|
||||
while (!Platforms.empty()) {
|
||||
auto Platform = Platforms.back();
|
||||
Platforms.pop_back();
|
||||
auto devices = Platform.get_devices();
|
||||
std::string backend_type = get_device_backend_and_type(devices[0]);
|
||||
for (const auto &device : devices) {
|
||||
backend_devices[backend_type].push_back(device);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> keys;
|
||||
for (auto it = backend_devices.begin(); it != backend_devices.end();
|
||||
++it) {
|
||||
keys.push_back(it->first);
|
||||
}
|
||||
std::sort(keys.begin(), keys.end(), compare_backend);
|
||||
|
||||
for (auto &key : keys) {
|
||||
std::vector<sycl::device> devs = backend_devices[key];
|
||||
std::sort(devs.begin(), devs.end(), compare_dev);
|
||||
for (const auto &dev : devs) {
|
||||
sycl_all_devs.push_back(dev);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &dev : sycl_all_devs) {
|
||||
if (dev == default_device) {
|
||||
continue;
|
||||
}
|
||||
_devs.push_back(std::make_shared<device_ext>(dev));
|
||||
if (_cpu_device == -1 && dev.is_cpu()) {
|
||||
_cpu_device = _devs.size() - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
void check_id(unsigned int id) const {
|
||||
if (id >= _devs.size()) {
|
||||
throw std::runtime_error("invalid device id");
|
||||
}
|
||||
}
|
||||
std::vector<std::shared_ptr<device_ext>> _devs;
|
||||
/// DEFAULT_DEVICE_ID is used when current_device_id() cannot find the current
/// thread id in _thread2dev_map, which means the default device should be
/// used for the current thread.
|
||||
const unsigned int DEFAULT_DEVICE_ID = 0;
|
||||
/// thread-id to device-id map.
|
||||
std::map<unsigned int, unsigned int> _thread2dev_map;
|
||||
int _cpu_device = -1;
|
||||
};
|
||||
|
||||
/// Destroys the object that `event` points to by applying `delete` to it.
static void destroy_event(event_ptr event) {
    delete event;
}
|
||||
|
||||
/// Returns the default queue of the device manager's current device.
static inline sycl::queue &get_default_queue() {
    device_ext &active_device = dev_mgr::instance().current_device();
    return active_device.default_queue();
}
|
||||
|
||||
static inline unsigned int select_device(unsigned int id) {
|
||||
dev_mgr::instance().select_device(id);
|
||||
return id;
|
||||
}
|
||||
|
||||
/// Verifies that `dev` supports every aspect listed in `props`.
///
/// Aspects are checked in list order. The first one the device lacks
/// produces an exception whose message names the missing feature and the
/// device.
///
/// @param dev   device to query.
/// @param props aspects the caller requires.
/// @throws std::runtime_error for the first aspect that `dev` does not have.
inline void
has_capability_or_fail(const sycl::device &dev,
                       const std::initializer_list<sycl::aspect> &props) {
// Each entry of the included .def files expands to a `case` label that
// returns the aspect's identifier spelled as a string.
#define __SYCL_ASPECT(ASPECT, ID)                                              \
    case sycl::aspect::ASPECT:                                                 \
        return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
    const auto aspect_name = [](sycl::aspect aspect_id) -> std::string {
        switch (aspect_id) {
#include <sycl/info/aspects.def>
#include <sycl/info/aspects_deprecated.def>
        default:
            return "unknown aspect";
        }
    };
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
#undef __SYCL_ASPECT_DEPRECATED
#undef __SYCL_ASPECT

    for (const auto &required : props) {
        if (dev.has(required))
            continue;

        // fp64 and fp16 are reported with their C++ type names; everything
        // else is reported with the aspect identifier.
        if (required == sycl::aspect::fp64) {
            throw std::runtime_error("'double' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
        }
        if (required == sycl::aspect::fp16) {
            throw std::runtime_error("'half' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
        }
        throw std::runtime_error(
            "'" + aspect_name(required) + "' is not supported in '" +
            dev.get_info<sycl::info::device::name>() + "' device");
    }
}
|
||||
|
||||
static inline unsigned int get_current_device_id() {
|
||||
return dev_mgr::instance().current_device_id();
|
||||
}
|
||||
|
||||
/// Returns the device that dev_mgr::current_device() reports for the
/// singleton device manager.
static inline device_ext &get_current_device() {
    dev_mgr &manager = dev_mgr::instance();
    return manager.current_device();
}
|
||||
|
||||
/// Returns the in-order queue of the device manager's current device.
static inline sycl::queue &get_in_order_queue() {
    device_ext &active_device = dev_mgr::instance().current_device();
    return active_device.in_order_queue();
}
|
||||
|
||||
} // namespace dpct
|
File diff suppressed because it is too large
Load diff
237
ggml/src/ggml-sycl/dpct/math.hpp
Normal file
237
ggml/src/ggml-sycl/dpct/math.hpp
Normal file
|
@ -0,0 +1,237 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
namespace dpct {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Applies a binary operation to two vectors of type VecT.
///
/// Primary template: used when BinaryOperation cannot be invoked with two
/// whole VecT values. The operation is applied to each pair of elements and
/// the results are gathered into a new VecT.
template <typename VecT, class BinaryOperation, class = void>
class vectorized_binary {
public:
    inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
        VecT result;
        size_t lane = 0;
        while (lane < result.size()) {
            result[lane] = binary_op(a[lane], b[lane]);
            ++lane;
        }
        return result;
    }
};

/// Partial specialization selected when BinaryOperation is invocable with
/// two whole VecT values. The operation is called once and its result is
/// reinterpreted as VecT through the result's `as` member.
template <typename VecT, class BinaryOperation>
class vectorized_binary<
    VecT, BinaryOperation,
    std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>> {
public:
    inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
        auto &&whole = binary_op(a, b);
        return whole.template as<VecT>();
    }
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Splits `val` into four 8-bit lanes and converts each lane back to T.
///
/// The lanes are viewed as int8_t when T is signed and as uint8_t otherwise,
/// so the conversion to T sign-extends or zero-extends accordingly.
///
/// @param val value whose storage is reinterpreted as four bytes.
/// @return a four-element vector holding the converted lanes.
template <typename T> sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
    using lane_t = std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>;
    const sycl::vec<T, 1> packed(val);
    const auto lanes = packed.template as<sycl::vec<lane_t, 4>>();
    return lanes.template convert<T>();
}
|
||||
|
||||
/// Accumulator type for a dot product of T1 and T2 operands: uint32_t when
/// both operand types are unsigned, int32_t in every other case.
template <typename T1, typename T2>
using dot_product_acc_t = std::conditional_t<
    std::conjunction_v<std::is_unsigned<T1>, std::is_unsigned<T2>>, uint32_t,
    int32_t>;
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
inline auto dp4a(T1 a, T2 b, T3 c) {
|
||||
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
|
||||
defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 610
|
||||
dot_product_acc_t<T1, T2> res;
|
||||
if constexpr (std::is_same_v<dot_product_acc_t<T1, T2>, uint32_t>) {
|
||||
asm volatile("dp4a.u32.u32 %0, %1, %2, %3;"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
} else {
|
||||
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
return res;
|
||||
#else
|
||||
dot_product_acc_t<T1, T2> res = c;
|
||||
auto va = extract_and_sign_or_zero_extend4(a);
|
||||
auto vb = extract_and_sign_or_zero_extend4(b);
|
||||
res += va[0] * vb[0];
|
||||
res += va[1] * vb[1];
|
||||
res += va[2] * vb[2];
|
||||
res += va[3] * vb[3];
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct sub_sat {
|
||||
template <typename T> auto operator()(const T x, const T y) const {
|
||||
return sycl::sub_sat(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename S, typename T> inline T vectorized_min(T a, T b) {
|
||||
sycl::vec<T, 1> v0{a}, v1{b};
|
||||
auto v2 = v0.template as<S>();
|
||||
auto v3 = v1.template as<S>();
|
||||
auto v4 = sycl::min(v2, v3);
|
||||
v0 = v4.template as<sycl::vec<T, 1>>();
|
||||
return v0;
|
||||
}
|
||||
|
||||
// pow function overloads.
// An `int` exponent with a `float` or `double` base goes through sycl::pown;
// matching floating-point base and exponent go through sycl::pow.
inline float pow(const float a, const int b) {
    return sycl::pown(a, b);
}

inline double pow(const double a, const int b) {
    return sycl::pown(a, b);
}

inline float pow(const float a, const float b) {
    return sycl::pow(a, b);
}

inline double pow(const double a, const double b) {
    return sycl::pow(a, b);
}

// Floating-point base with any other exponent type: the exponent is
// converted to the base type first.
template <typename T, typename U>
inline std::enable_if_t<std::is_floating_point_v<T>, T>
pow(const T a, const U b) {
    const T exponent = static_cast<T>(b);
    return sycl::pow(a, exponent);
}

// Non-floating-point base: both operands are converted to double.
template <typename T, typename U>
inline std::enable_if_t<!std::is_floating_point_v<T>, double>
pow(const T a, const U b) {
    const double base = static_cast<double>(a);
    const double exponent = static_cast<double>(b);
    return sycl::pow(base, exponent);
}
|
||||
|
||||
inline double min(const double a, const float b) {
|
||||
return sycl::fmin(a, static_cast<double>(b));
|
||||
}
|
||||
inline double min(const float a, const double b) {
|
||||
return sycl::fmin(static_cast<double>(a), b);
|
||||
}
|
||||
inline float min(const float a, const float b) { return sycl::fmin(a, b); }
|
||||
inline double min(const double a, const double b) { return sycl::fmin(a, b); }
|
||||
inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) {
|
||||
return sycl::min(a, static_cast<std::uint32_t>(b));
|
||||
}
|
||||
inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) {
|
||||
return sycl::min(static_cast<std::uint32_t>(a), b);
|
||||
}
|
||||
inline std::int32_t min(const std::int32_t a, const std::int32_t b) {
|
||||
return sycl::min(a, b);
|
||||
}
|
||||
inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) {
|
||||
return sycl::min(a, b);
|
||||
}
|
||||
inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) {
|
||||
return sycl::min(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) {
|
||||
return sycl::min(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
inline std::int64_t min(const std::int64_t a, const std::int64_t b) {
|
||||
return sycl::min(a, b);
|
||||
}
|
||||
inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) {
|
||||
return sycl::min(a, b);
|
||||
}
|
||||
inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) {
|
||||
return sycl::min(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) {
|
||||
return sycl::min(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) {
|
||||
return sycl::min(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) {
|
||||
return sycl::min(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
// max function overloads.
|
||||
// For floating-point types, `float` or `double` arguments are acceptable.
|
||||
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
|
||||
// `std::int64_t` type arguments are acceptable.
|
||||
inline double max(const double a, const float b) {
|
||||
return sycl::fmax(a, static_cast<double>(b));
|
||||
}
|
||||
inline double max(const float a, const double b) {
|
||||
return sycl::fmax(static_cast<double>(a), b);
|
||||
}
|
||||
inline float max(const float a, const float b) { return sycl::fmax(a, b); }
|
||||
inline double max(const double a, const double b) { return sycl::fmax(a, b); }
|
||||
inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) {
|
||||
return sycl::max(a, static_cast<std::uint32_t>(b));
|
||||
}
|
||||
inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) {
|
||||
return sycl::max(static_cast<std::uint32_t>(a), b);
|
||||
}
|
||||
inline std::int32_t max(const std::int32_t a, const std::int32_t b) {
|
||||
return sycl::max(a, b);
|
||||
}
|
||||
inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) {
|
||||
return sycl::max(a, b);
|
||||
}
|
||||
inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) {
|
||||
return sycl::max(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) {
|
||||
return sycl::max(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
inline std::int64_t max(const std::int64_t a, const std::int64_t b) {
|
||||
return sycl::max(a, b);
|
||||
}
|
||||
inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) {
|
||||
return sycl::max(a, b);
|
||||
}
|
||||
inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) {
|
||||
return sycl::max(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) {
|
||||
return sycl::max(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) {
|
||||
return sycl::max(a, static_cast<std::uint64_t>(b));
|
||||
}
|
||||
inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) {
|
||||
return sycl::max(static_cast<std::uint64_t>(a), b);
|
||||
}
|
||||
|
||||
template <typename VecT, class BinaryOperation>
|
||||
inline unsigned vectorized_binary(unsigned a, unsigned b,
|
||||
const BinaryOperation binary_op) {
|
||||
sycl::vec<unsigned, 1> v0{a}, v1{b};
|
||||
auto v2 = v0.as<VecT>();
|
||||
auto v3 = v1.as<VecT>();
|
||||
auto v4 =
|
||||
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
|
||||
v0 = v4.template as<sycl::vec<unsigned, 1>>();
|
||||
return v0;
|
||||
}
|
||||
|
||||
} // namespace dpct
|
1006
ggml/src/ggml-sycl/dpct/memory.hpp
Normal file
1006
ggml/src/ggml-sycl/dpct/memory.hpp
Normal file
File diff suppressed because it is too large
Load diff
64
ggml/src/ggml-sycl/dpct/util.hpp
Normal file
64
ggml/src/ggml-sycl/dpct/util.hpp
Normal file
|
@ -0,0 +1,64 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
/***************************************************************************
|
||||
*
|
||||
* Copyright (C) Codeplay Software Ltd.
|
||||
*
|
||||
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
**************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
namespace dpct {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Thin wrapper around a value of type T, made distinct per `tag`.
///
/// Two instantiations with different tags are different types even when T
/// is the same. The wrapper converts implicitly from T and back to T.
///
/// @tparam tag type used only to tell instantiations apart.
/// @tparam T   type of the wrapped value.
template <typename tag, typename T> class generic_error_type {
public:
    /// Creates a wrapper holding a value-initialized T (zero for arithmetic
    /// types), so reading a default-constructed object is well defined.
    generic_error_type() = default;

    /// Wraps `value`. Deliberately not `explicit`, so a plain T converts.
    generic_error_type(T value) : value{value} {}

    /// Returns the wrapped value.
    operator T() const { return value; }

private:
    // Value-initialized so the defaulted constructor never leaves it
    // indeterminate.
    T value{};
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
|
||||
unsigned int logical_sub_group_size = 32) {
|
||||
unsigned int id = g.get_local_linear_id();
|
||||
unsigned int start_index =
|
||||
id / logical_sub_group_size * logical_sub_group_size;
|
||||
unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
|
||||
return sycl::select_from_group(g, x,
|
||||
target_offset < logical_sub_group_size
|
||||
? start_index + target_offset
|
||||
: id);
|
||||
}
|
||||
|
||||
using err0 = detail::generic_error_type<struct err0_tag, int>;
|
||||
using err1 = detail::generic_error_type<struct err1_tag, int>;
|
||||
|
||||
} // namespace dpct
|
|
@ -1,5 +1,9 @@
|
|||
#include "rope.hpp"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
struct rope_corr_dims {
|
||||
float v[2];
|
||||
};
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
#ifndef GGML_SYCL_VECDOTQ_HPP
|
||||
#define GGML_SYCL_VECDOTQ_HPP
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "dpct/math.hpp"
|
||||
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue