From 4cffe910c3be4f6c7fb1ad8f91cfd109b0f738d1 Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Mon, 19 Aug 2024 11:36:35 +0800
Subject: [PATCH] add onednn

---
 ggml/src/CMakeLists.txt     |  5 ++-
 ggml/src/ggml-sycl.cpp      | 12 ++++++-
 ggml/src/ggml-sycl/gemm.hpp | 71 +++++++++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+), 2 deletions(-)
 create mode 100644 ggml/src/ggml-sycl/gemm.hpp

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 1775ef3cc..5433ebe09 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -549,10 +549,13 @@ if (GGML_SYCL)
     file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")

+    find_package(DNNL)
+    message(STATUS "DNNL found: ${DNNL_FOUND}")
+    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
     if (WIN32)
         find_package(IntelSYCL REQUIRED)
         find_package(MKL REQUIRED)
-        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL DNNL::dnnl)
     else()
         if (GGML_SYCL_TARGET STREQUAL "INTEL")
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 94cd4b110..538c7d3d9 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -38,6 +38,7 @@

 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"

 bool ggml_sycl_loaded(void);
 void ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(

         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2491,6 +2493,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2513,13 +2519,17 @@ inline void ggml_sycl_op_mul_mat_sycl(

         const float alpha = 1.0f;
         const float beta = 0.0f;
-
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#else
+        DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     (void) dst;
     (void) src1_ddq_i;
diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
new file mode 100644
index 000000000..747d8f3d0
--- /dev/null
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -0,0 +1,71 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <type_traits>
+#include <unordered_map>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template <typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static inline void row_gemm(sycl::queue& q, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        // Get the device associated with the queue
+        sycl::device dev = q.get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q.get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+};
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP
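
A minimal usage sketch of the new wrapper follows (not part of the diff, purely illustrative): it drives DnnlGemmWrapper::row_gemm with a plain row-major C[m x n] = A[m x k] * B[k x n], analogous to how ggml_sycl_op_mul_mat_sycl calls it above. The file name, build line, matrix sizes and the in-order queue property are assumptions, not anything this patch adds; it presumes oneDNN was built with the SYCL runtime, that the ggml headers are on the include path, and that GGML_SYCL_DNNL=1 is defined so the wrapper is compiled in.

// onednn_gemm_demo.cpp -- illustrative only, not part of the patch.
// Assumed build line (paths and flags depend on your setup):
//   icpx -fsycl -DGGML_SYCL_DNNL=1 -I<ggml>/include -I<ggml>/src onednn_gemm_demo.cpp -ldnnl
#include <cstdio>
#include <sycl/sycl.hpp>

#include "ggml-sycl/gemm.hpp"

int main() {
    // ggml-sycl uses in-order queues; mirror that so a single wait() is enough.
    sycl::queue q{sycl::property::queue::in_order{}};

    // C[m x n] = A[m x k] * B[k x n], all row-major, no transposes.
    const int m = 2, n = 2, k = 3;
    float *A = sycl::malloc_shared<float>(m * k, q);
    float *B = sycl::malloc_shared<float>(k * n, q);
    float *C = sycl::malloc_shared<float>(m * n, q);
    for (int i = 0; i < m * k; ++i) A[i] = float(i + 1); // rows {1,2,3}, {4,5,6}
    for (int i = 0; i < k * n; ++i) B[i] = 1.0f;         // all ones

    DnnlGemmWrapper::row_gemm(q, /*a_trans=*/false, /*b_trans=*/false, m, n, k,
                              A, DnnlGemmWrapper::to_dt<float>(),
                              B, DnnlGemmWrapper::to_dt<float>(),
                              C, DnnlGemmWrapper::to_dt<float>());
    q.wait(); // row_gemm only submits work; the caller synchronizes.

    // Expected: each C element is the corresponding row sum of A -> 6 6 15 15.
    for (int i = 0; i < m * n; ++i) printf("C[%d] = %g\n", i, C[i]);

    sycl::free(A, q);
    sycl::free(B, q);
    sycl::free(C, q);
    return 0;
}

Since row_gemm only submits the matmul to a stream created from the caller's queue, the sketch waits on the queue before reading C; ggml-sycl gets the same ordering from its in-order queues. Note also that on the fp16 path above the inputs are passed as f16 with an f32 destination, so oneDNN writes float output directly into dst_dd_i and the separate to_fp32 conversion used on the MKL path is not needed.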