Address PR comments to increase readibility
This commit is contained in:
parent
ffd0a998c7
commit
f6e6fc4d47
3 changed files with 34 additions and 24 deletions
|
@ -1689,13 +1689,14 @@ namespace dpct
|
||||||
auto data_a = get_memory<const Ta>(a);
|
auto data_a = get_memory<const Ta>(a);
|
||||||
auto data_b = get_memory<const Tb>(b);
|
auto data_b = get_memory<const Tb>(b);
|
||||||
auto data_c = get_memory<Tc>(c);
|
auto data_c = get_memory<Tc>(c);
|
||||||
oneapi::mkl::blas::column_major::gemm(
|
|
||||||
#ifdef GGML_SYCL_NVIDIA
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
||||||
|
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||||
|
beta_value, data_c, ldc);
|
||||||
#else
|
#else
|
||||||
q,
|
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||||
|
beta_value, data_c, ldc);
|
||||||
#endif
|
#endif
|
||||||
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VecT, class BinaryOperation, class = void>
|
template <typename VecT, class BinaryOperation, class = void>
|
||||||
|
@ -1758,17 +1759,22 @@ namespace dpct
|
||||||
matrix_info->ld_info[2] = ldc;
|
matrix_info->ld_info[2] = ldc;
|
||||||
matrix_info->groupsize_info = batch_size;
|
matrix_info->groupsize_info = batch_size;
|
||||||
|
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
||||||
#ifdef GGML_SYCL_NVIDIA
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
||||||
|
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
||||||
|
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
||||||
|
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||||
|
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
||||||
|
&(matrix_info->groupsize_info));
|
||||||
#else
|
#else
|
||||||
q,
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
#endif
|
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
||||||
matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
||||||
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
||||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||||
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||||
|
#endif
|
||||||
|
|
||||||
q.submit([&](sycl::handler &cgh)
|
q.submit([&](sycl::handler &cgh)
|
||||||
{
|
{
|
||||||
|
@ -1790,14 +1796,16 @@ namespace dpct
|
||||||
auto data_a = get_memory<const Ta>(a);
|
auto data_a = get_memory<const Ta>(a);
|
||||||
auto data_b = get_memory<const Tb>(b);
|
auto data_b = get_memory<const Tb>(b);
|
||||||
auto data_c = get_memory<Tc>(c);
|
auto data_c = get_memory<Tc>(c);
|
||||||
oneapi::mkl::blas::column_major::gemm_batch(
|
|
||||||
#ifdef GGML_SYCL_NVIDIA
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
||||||
|
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
||||||
|
batch_size);
|
||||||
#else
|
#else
|
||||||
q,
|
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
||||||
|
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
||||||
|
stride_c, batch_size);
|
||||||
#endif
|
#endif
|
||||||
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
|
||||||
data_c, ldc, stride_c, batch_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
|
@ -2561,15 +2561,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
#if !GGML_SYCL_DNNL
|
#if !GGML_SYCL_DNNL
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
|
||||||
# ifdef GGML_SYCL_NVIDIA
|
# ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
|
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
|
||||||
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
|
||||||
|
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||||
# else
|
# else
|
||||||
*stream,
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
# endif
|
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
|
||||||
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
||||||
dst_dd_i, ldc)));
|
dst_dd_i, ldc)));
|
||||||
|
# endif
|
||||||
#else
|
#else
|
||||||
auto dnnl_stream = ctx.stream_dnnl(stream);
|
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||||
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
||||||
|
|
|
@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Perform matrix multiplication using oneMKL GEMM
|
// Perform matrix multiplication using oneMKL GEMM
|
||||||
oneapi::mkl::blas::column_major::gemm(
|
|
||||||
#ifdef GGML_SYCL_NVIDIA
|
#ifdef GGML_SYCL_NVIDIA
|
||||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
|
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
|
||||||
|
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
|
||||||
|
ne00, src1_d, ldb, beta, dst_d, ne0);
|
||||||
#else
|
#else
|
||||||
*stream,
|
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
|
||||||
|
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
||||||
#endif
|
#endif
|
||||||
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d,
|
|
||||||
ne0);
|
|
||||||
}
|
}
|
||||||
catch (sycl::exception const& exc) {
|
catch (sycl::exception const& exc) {
|
||||||
std::cerr << exc.what() << std::endl;
|
std::cerr << exc.what() << std::endl;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue