From a7e15b03559ed3bc91e1648f225943455cdb0a4b Mon Sep 17 00:00:00 2001 From: nscipione Date: Fri, 29 Nov 2024 11:29:00 +0000 Subject: [PATCH] [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend Move to compile-time backend selection to avoid latency at run time. Apply it to all oneMKL gemm calls, for the NVIDIA backend only. Signed-off-by: nscipione --- ggml/src/ggml-sycl/CMakeLists.txt | 3 ++- ggml/src/ggml-sycl/dpct/helper.hpp | 21 ++++++++++++++++++--- ggml/src/ggml-sycl/ggml-sycl.cpp | 7 ++++++- ggml/src/ggml-sycl/outprod.cpp | 7 ++++++- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 83f223fd7..3579a311a 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -68,7 +68,8 @@ else() target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") - target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl) + add_compile_definitions(GGML_SYCL_NVIDIA) + target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index c2f28bb49..b92411cc3 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -1690,7 +1690,12 @@ namespace dpct auto data_b = get_memory(b); auto data_c = get_memory(c); oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::backend_selector{q}, +#else + q, +#endif + a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, 
data_c, ldc); } @@ -1755,7 +1760,12 @@ namespace dpct matrix_info->groupsize_info = batch_size; sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::backend_selector{q}, +#else + q, +#endif + matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), matrix_info->ld_info, @@ -1784,7 +1794,12 @@ namespace dpct auto data_b = get_memory(b); auto data_c = get_memory(c); oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::backend_selector{q}, +#else + q, +#endif + a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, batch_size); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 808f74fa0..0f97bc7ef 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2562,7 +2562,12 @@ inline void ggml_sycl_op_mul_mat_sycl( const float beta = 0.0f; #if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::backend_selector{*stream}, +#else + *stream, +#endif + oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index e61cdc2ca..7a2bfd445 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -40,7 +40,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr try { // Perform matrix multiplication using oneMKL GEMM - 
oneapi::mkl::blas::column_major::gemm(*stream, + oneapi::mkl::blas::column_major::gemm( +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::backend_selector{*stream}, +#else + *stream, +#endif oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,