Address PR comments to increase readibility

This commit is contained in:
nscipione 2024-12-02 14:58:30 +00:00
parent ffd0a998c7
commit f6e6fc4d47
3 changed files with 34 additions and 24 deletions

View file

@ -1689,13 +1689,14 @@ namespace dpct
auto data_a = get_memory<const Ta>(a); auto data_a = get_memory<const Ta>(a);
auto data_b = get_memory<const Tb>(b); auto data_b = get_memory<const Tb>(b);
auto data_c = get_memory<Tc>(c); auto data_c = get_memory<Tc>(c);
oneapi::mkl::blas::column_major::gemm(
#ifdef GGML_SYCL_NVIDIA #ifdef GGML_SYCL_NVIDIA
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
beta_value, data_c, ldc);
#else #else
q, oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
beta_value, data_c, ldc);
#endif #endif
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc);
} }
template <typename VecT, class BinaryOperation, class = void> template <typename VecT, class BinaryOperation, class = void>
@ -1758,17 +1759,22 @@ namespace dpct
matrix_info->ld_info[2] = ldc; matrix_info->ld_info[2] = ldc;
matrix_info->groupsize_info = batch_size; matrix_info->groupsize_info = batch_size;
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
#ifdef GGML_SYCL_NVIDIA #ifdef GGML_SYCL_NVIDIA
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
&(matrix_info->groupsize_info));
#else #else
q, sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
#endif q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#endif
q.submit([&](sycl::handler &cgh) q.submit([&](sycl::handler &cgh)
{ {
@ -1790,14 +1796,16 @@ namespace dpct
auto data_a = get_memory<const Ta>(a); auto data_a = get_memory<const Ta>(a);
auto data_b = get_memory<const Tb>(b); auto data_b = get_memory<const Tb>(b);
auto data_c = get_memory<Tc>(c); auto data_c = get_memory<Tc>(c);
oneapi::mkl::blas::column_major::gemm_batch(
#ifdef GGML_SYCL_NVIDIA #ifdef GGML_SYCL_NVIDIA
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, oneapi::mkl::blas::column_major::gemm_batch(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
batch_size);
#else #else
q, oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
stride_c, batch_size);
#endif #endif
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
data_c, ldc, stride_c, batch_size);
} }
} // namespace detail } // namespace detail

View file

@ -2561,15 +2561,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
#if !GGML_SYCL_DNNL #if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
# ifdef GGML_SYCL_NVIDIA # ifdef GGML_SYCL_NVIDIA
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
# else # else
*stream, SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
# endif *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
dst_dd_i, ldc))); dst_dd_i, ldc)));
# endif
#else #else
auto dnnl_stream = ctx.stream_dnnl(stream); auto dnnl_stream = ctx.stream_dnnl(stream);
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),

View file

@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
try { try {
// Perform matrix multiplication using oneMKL GEMM // Perform matrix multiplication using oneMKL GEMM
oneapi::mkl::blas::column_major::gemm(
#ifdef GGML_SYCL_NVIDIA #ifdef GGML_SYCL_NVIDIA
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
ne00, src1_d, ldb, beta, dst_d, ne0);
#else #else
*stream, oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
#endif #endif
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d,
ne0);
} }
catch (sycl::exception const& exc) { catch (sycl::exception const& exc) {
std::cerr << exc.what() << std::endl; std::cerr << exc.what() << std::endl;