From 3d0a64f092eac7b37a6edd38f478abfd378be6fa Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Mon, 19 Aug 2024 06:12:39 +0000 Subject: [PATCH] add sycl_f16 --- CMakePresets.json | 5 ++++- ggml/src/ggml-sycl.cpp | 8 +++++--- ggml/src/ggml-sycl/gemm.hpp | 39 ++++++++++++++++++++++++++++++++----- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index bdad38952..ce627b4d3 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -28,6 +28,7 @@ { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, { "name": "arm64-windows-msvc", "hidden": true, @@ -60,6 +61,8 @@ { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, - { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] } + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] } ] } diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 538c7d3d9..2ac3e4543 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -2483,7 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; -#if GGML_SYCL_DNNL +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, @@ -2495,7 +2495,9 @@ inline void ggml_sycl_op_mul_mat_sycl( to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); #else DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), - src0_ptr, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt()); + src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); #endif } else { @@ -2519,7 +2521,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; -#if GGML_SYCL_DNNL +#if !GGML_SYCL_DNNL SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 747d8f3d0..66c45cec8 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -39,11 +39,39 @@ public: const void* a, dt at, const void* b, dt bt, void* c, dt ct) { // Get the device associated with the queue - sycl::device dev = q.get_device(); - // Get the context associated with the queue - sycl::context ctx = q.get_context(); - const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); - const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); + sycl::device dev = q.get_device(); + // Get the context associated with the queue + sycl::context ctx = q.get_context(); + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); + const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); + dnnl::memory::dims a_dims = { m, k }; + dnnl::memory::dims b_dims = { k, n }; + dnnl::memory::dims c_dims = { m, n }; + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); + + // Create the primitive. + auto matmul_prim = dnnl::matmul(matmul_pd); + // Primitive arguments. + std::unordered_map matmul_args; + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); + + matmul_prim.execute(stream, matmul_args); + } + + + static inline void row_gemm(const dnnl::stream& stream, bool a_trans, + bool b_trans, int m, int n, int k, + const void* a, dt at, const void* b, dt bt, void* c, dt ct) + { + auto const eng = stream.get_engine(); dnnl::memory::dims a_dims = { m, k }; dnnl::memory::dims b_dims = { k, n }; dnnl::memory::dims c_dims = { m, n }; @@ -66,6 +94,7 @@ public: matmul_prim.execute(stream, matmul_args); } }; + #endif #endif // GGML_SYCL_GEMM_HPP