Formatting

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
nscipione 2025-01-15 15:13:44 +00:00
parent 6b77639258
commit b0f14c5c2a
3 changed files with 70 additions and 100 deletions

View file

@ -352,14 +352,12 @@ struct ggml_backend_sycl_context {
ggml_sycl_pool & host_pool(int device) { ggml_sycl_pool & host_pool(int device) {
if (host_pools[device] == nullptr) { if (host_pools[device] == nullptr) {
host_pools[device] = new_pool_for_host(stream(device,0), device); host_pools[device] = new_pool_for_host(stream(device, 0), device);
} }
return *host_pools[device]; return *host_pools[device];
} }
ggml_sycl_pool & host_pool() { ggml_sycl_pool & host_pool() { return host_pool(device); }
return host_pool(device);
}
}; };
// common device functions // common device functions

View file

@ -82,14 +82,12 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
return device_type.str(); return device_type.str();
} }
template<typename Ts> template <typename Ts> struct matrix_info_t {
struct matrix_info_t
{
oneapi::mkl::transpose transpose_info[2]; oneapi::mkl::transpose transpose_info[2];
Ts value_info[2]; Ts value_info[2];
std::int64_t size_info[3]; std::int64_t size_info[3];
std::int64_t ld_info[3]; std::int64_t ld_info[3];
std::int64_t groupsize_info; std::int64_t groupsize_info;
}; };
namespace dpct namespace dpct
@ -1737,13 +1735,10 @@ namespace dpct
}; };
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
oneapi::mkl::transpose b_trans, int m, int n, int k, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
const void *alpha, const void **a, int lda, int ldb, const void * beta, void ** c, int ldc, int batch_size,
const void **b, int ldb, const void *beta, void **c, matrix_info_t<float> * matrix_info) {
int ldc, int batch_size, matrix_info_t<float>* matrix_info)
{
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
@ -1763,19 +1758,18 @@ namespace dpct
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1, reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
reinterpret_cast<Ts*>(matrix_info->value_info+1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
&(matrix_info->groupsize_info)); reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#else #else
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
matrix_info->ld_info + 1, reinterpret_cast<Ts*>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#endif #endif
} }
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>
@ -2418,15 +2412,11 @@ namespace dpct
/// \param [in] ldc Leading dimension of C. /// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors. /// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
oneapi::mkl::transpose b_trans, int m, int n, int k, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
const void *alpha, const void *a[], const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
library_data_t a_type, int lda, const void *b[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
library_data_t b_type, int ldb, const void *beta, matrix_info_t<float> * matrix_info) {
void *c[], library_data_t c_type, int ldc,
int batch_size, library_data_t scaling_type,
matrix_info_t<float>* matrix_info)
{
std::uint64_t key = std::uint64_t key =
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
switch (key) switch (key)
@ -2435,28 +2425,24 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<float, float, float, float>( detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, beta, c, ldc, batch_size, matrix_info);
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double,
library_data_t::real_double, library_data_t::real_double): library_data_t::real_double, library_data_t::real_double):
{ {
detail::gemm_batch_impl<double, double, double, double>( detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, beta, c, ldc, batch_size, matrix_info);
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_half): library_data_t::real_half, library_data_t::real_half):
{ {
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
sycl::half>(q, a_trans, b_trans, m, n, k, alpha, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
a, lda, b, ldb, beta, c, ldc,
batch_size, matrix_info);
break; break;
} }
#ifdef __INTEL_MKL__ #ifdef __INTEL_MKL__
@ -2464,19 +2450,16 @@ namespace dpct
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float): library_data_t::real_bfloat16, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
b, ldb, beta, c, ldc, batch_size, matrix_info);
break; break;
} }
#endif #endif
@ -2488,10 +2471,9 @@ namespace dpct
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q); dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
float beta_float = float beta_float =
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q); dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
float>(q, a_trans, b_trans, m, n, k, &alpha_float, q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
a, lda, b, ldb, &beta_float, c, ldc, matrix_info);
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2499,8 +2481,7 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>( detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2508,8 +2489,7 @@ namespace dpct
library_data_t::real_float, library_data_t::real_float): library_data_t::real_float, library_data_t::real_float):
{ {
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>( detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
batch_size, matrix_info);
break; break;
} }
case detail::get_type_combination_id( case detail::get_type_combination_id(
@ -2523,8 +2503,7 @@ namespace dpct
sycl::half alpha_half(alpha_value); sycl::half alpha_half(alpha_value);
sycl::half beta_half(beta_value); sycl::half beta_half(beta_value);
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
batch_size, matrix_info);
break; break;
} }
default: default:

View file

@ -1174,20 +1174,20 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
}; };
struct ggml_sycl_pool_host : public ggml_sycl_pool { struct ggml_sycl_pool_host : public ggml_sycl_pool {
queue_ptr qptr; queue_ptr qptr;
int device; int device;
inline static int counter{ 0 };
inline static int counter{0};
struct ggml_sycl_buffer { struct ggml_sycl_buffer {
void * ptr = nullptr; void * ptr = nullptr;
size_t size = 0; size_t size = 0;
}; };
// Set arbitrarily to 64 // Set arbitrarily to 64
static constexpr int MAX_POOL_SIZE{64}; static constexpr int MAX_POOL_SIZE{ 64 };
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE); std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
size_t pool_size = 0; size_t pool_size = 0;
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
@ -1205,32 +1205,29 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
} }
void * alloc(size_t size, size_t * actual_size) override { void * alloc(size_t size, size_t * actual_size) override {
if ( counter == MAX_POOL_SIZE){ if (counter == MAX_POOL_SIZE) {
ggml_sycl_buffer b = buffer_pool[0]; ggml_sycl_buffer b = buffer_pool[0];
size_t look_ahead_size = (size_t) (1.05 * size); size_t look_ahead_size = (size_t) (1.05 * size);
void *ptr = b.ptr; void * ptr = b.ptr;
*actual_size = b.size; *actual_size = b.size;
counter = 1; counter = 1;
return ptr; return ptr;
} }
ggml_sycl_buffer& b = buffer_pool[counter]; ggml_sycl_buffer & b = buffer_pool[counter];
if (b.ptr == nullptr) { if (b.ptr == nullptr) {
void * ptr; void * ptr;
SYCL_CHECK( SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_host( if (!ptr) {
size, *qptr))); GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
if (!ptr) { return nullptr;
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); }
return nullptr; pool_size += size;
} *actual_size = size;
pool_size += size; counter = counter + 1;
*actual_size = size; return ptr;
counter = counter + 1; } else if (b.ptr != nullptr) {
return ptr;
}
else if (b.ptr != nullptr) {
++counter; ++counter;
b.size = size; b.size = size;
return b.ptr; return b.ptr;
@ -1241,9 +1238,9 @@ struct ggml_sycl_pool_host : public ggml_sycl_pool {
// if the pool is not completed add the pointer to it in place of the first nullptr found. // if the pool is not completed add the pointer to it in place of the first nullptr found.
// Otherwise do nothing, pointers will be freed once the pool is deallocated. // Otherwise do nothing, pointers will be freed once the pool is deallocated.
for (int i = 0; i < MAX_POOL_SIZE; ++i) { for (int i = 0; i < MAX_POOL_SIZE; ++i) {
ggml_sycl_buffer& b = buffer_pool[i]; ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr == nullptr) { if (b.ptr == nullptr) {
b.ptr = ptr; b.ptr = ptr;
b.size = size; b.size = size;
return; return;
} }
@ -3446,7 +3443,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(),1); ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13); sycl::range<3> block_dims(1, ne12, ne13);
/* /*
@ -3475,14 +3472,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
}); });
} }
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::mkl::transpose::trans, *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **)(ptrs_src.get() + 0 * ne23), (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
dpct::library_data_t::real_half, nb01 / nb00, (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
(const void **)(ptrs_src.get() + 1 * ne23),
dpct::library_data_t::real_half, nb11 / nb10, beta,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type, matrix_info.get())));
} }
} }
catch (sycl::exception const &exc) { catch (sycl::exception const &exc) {