Formatting
Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
parent
6b77639258
commit
b0f14c5c2a
3 changed files with 70 additions and 100 deletions
|
@ -357,9 +357,7 @@ struct ggml_backend_sycl_context {
|
||||||
return *host_pools[device];
|
return *host_pools[device];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_sycl_pool & host_pool() {
|
ggml_sycl_pool & host_pool() { return host_pool(device); }
|
||||||
return host_pool(device);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// common device functions
|
// common device functions
|
||||||
|
|
|
@ -82,9 +82,7 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
||||||
return device_type.str();
|
return device_type.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Ts>
|
template <typename Ts> struct matrix_info_t {
|
||||||
struct matrix_info_t
|
|
||||||
{
|
|
||||||
oneapi::mkl::transpose transpose_info[2];
|
oneapi::mkl::transpose transpose_info[2];
|
||||||
Ts value_info[2];
|
Ts value_info[2];
|
||||||
std::int64_t size_info[3];
|
std::int64_t size_info[3];
|
||||||
|
@ -1737,13 +1735,10 @@ namespace dpct
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class Ta, class Tb, class Tc, class Ts>
|
template <class Ta, class Tb, class Tc, class Ts>
|
||||||
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
||||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
||||||
const void *alpha, const void **a, int lda,
|
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
||||||
const void **b, int ldb, const void *beta, void **c,
|
matrix_info_t<float> * matrix_info) {
|
||||||
int ldc, int batch_size, matrix_info_t<float>* matrix_info)
|
|
||||||
{
|
|
||||||
|
|
||||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||||
|
|
||||||
|
@ -1763,19 +1758,18 @@ namespace dpct
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
||||||
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
||||||
matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), reinterpret_cast<const Ta **>(a),
|
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
||||||
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||||
reinterpret_cast<Ts*>(matrix_info->value_info+1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
||||||
&(matrix_info->groupsize_info));
|
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||||
#else
|
#else
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
||||||
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
||||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||||
matrix_info->ld_info + 1, reinterpret_cast<Ts*>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
||||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb, class Tc, class Ts>
|
template <class Ta, class Tb, class Tc, class Ts>
|
||||||
|
@ -2418,15 +2412,11 @@ namespace dpct
|
||||||
/// \param [in] ldc Leading dimension of C.
|
/// \param [in] ldc Leading dimension of C.
|
||||||
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
||||||
/// \param [in] scaling_type Data type of the scaling factors.
|
/// \param [in] scaling_type Data type of the scaling factors.
|
||||||
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
||||||
oneapi::mkl::transpose b_trans, int m, int n, int k,
|
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
||||||
const void *alpha, const void *a[],
|
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
||||||
library_data_t a_type, int lda, const void *b[],
|
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
||||||
library_data_t b_type, int ldb, const void *beta,
|
matrix_info_t<float> * matrix_info) {
|
||||||
void *c[], library_data_t c_type, int ldc,
|
|
||||||
int batch_size, library_data_t scaling_type,
|
|
||||||
matrix_info_t<float>* matrix_info)
|
|
||||||
{
|
|
||||||
std::uint64_t key =
|
std::uint64_t key =
|
||||||
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
||||||
switch (key)
|
switch (key)
|
||||||
|
@ -2435,28 +2425,24 @@ namespace dpct
|
||||||
library_data_t::real_float, library_data_t::real_float,
|
library_data_t::real_float, library_data_t::real_float,
|
||||||
library_data_t::real_float, library_data_t::real_float):
|
library_data_t::real_float, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<float, float, float, float>(
|
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
beta, c, ldc, batch_size, matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_double, library_data_t::real_double,
|
library_data_t::real_double, library_data_t::real_double,
|
||||||
library_data_t::real_double, library_data_t::real_double):
|
library_data_t::real_double, library_data_t::real_double):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<double, double, double, double>(
|
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
beta, c, ldc, batch_size, matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_half, library_data_t::real_half,
|
library_data_t::real_half, library_data_t::real_half,
|
||||||
library_data_t::real_half, library_data_t::real_half):
|
library_data_t::real_half, library_data_t::real_half):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||||
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
a, lda, b, ldb, beta, c, ldc,
|
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#ifdef __INTEL_MKL__
|
#ifdef __INTEL_MKL__
|
||||||
|
@ -2464,19 +2450,16 @@ namespace dpct
|
||||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||||
library_data_t::real_bfloat16, library_data_t::real_float):
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
||||||
oneapi::mkl::bfloat16, float>(
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
||||||
library_data_t::real_float, library_data_t::real_float):
|
library_data_t::real_float, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
||||||
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -2488,10 +2471,9 @@ namespace dpct
|
||||||
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
||||||
float beta_float =
|
float beta_float =
|
||||||
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
||||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
|
||||||
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
|
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
|
||||||
a, lda, b, ldb, &beta_float, c, ldc,
|
matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2499,8 +2481,7 @@ namespace dpct
|
||||||
library_data_t::real_float, library_data_t::real_float):
|
library_data_t::real_float, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2508,8 +2489,7 @@ namespace dpct
|
||||||
library_data_t::real_float, library_data_t::real_float):
|
library_data_t::real_float, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2523,8 +2503,7 @@ namespace dpct
|
||||||
sycl::half alpha_half(alpha_value);
|
sycl::half alpha_half(alpha_value);
|
||||||
sycl::half beta_half(beta_value);
|
sycl::half beta_half(beta_value);
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||||
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
|
||||||
batch_size, matrix_info);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -1174,11 +1174,11 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
||||||
|
|
||||||
queue_ptr qptr;
|
queue_ptr qptr;
|
||||||
int device;
|
int device;
|
||||||
|
|
||||||
inline static int counter{ 0 };
|
inline static int counter{ 0 };
|
||||||
|
|
||||||
struct ggml_sycl_buffer {
|
struct ggml_sycl_buffer {
|
||||||
void * ptr = nullptr;
|
void * ptr = nullptr;
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
|
@ -1218,9 +1218,7 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
||||||
if (b.ptr == nullptr) {
|
if (b.ptr == nullptr) {
|
||||||
void * ptr;
|
void * ptr;
|
||||||
|
|
||||||
SYCL_CHECK(
|
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
|
||||||
CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host(
|
|
||||||
size, *qptr)));
|
|
||||||
if (!ptr) {
|
if (!ptr) {
|
||||||
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
|
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1229,8 +1227,7 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
||||||
*actual_size = size;
|
*actual_size = size;
|
||||||
counter = counter + 1;
|
counter = counter + 1;
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
} else if (b.ptr != nullptr) {
|
||||||
else if (b.ptr != nullptr) {
|
|
||||||
++counter;
|
++counter;
|
||||||
b.size = size;
|
b.size = size;
|
||||||
return b.ptr;
|
return b.ptr;
|
||||||
|
@ -3475,14 +3472,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||||
*main_stream, oneapi::mkl::transpose::trans,
|
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||||
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||||
(const void **)(ptrs_src.get() + 0 * ne23),
|
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||||
dpct::library_data_t::real_half, nb01 / nb00,
|
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
||||||
(const void **)(ptrs_src.get() + 1 * ne23),
|
|
||||||
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
|
||||||
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
|
||||||
cu_compute_type, matrix_info.get())));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue