add sycl_f16

This commit is contained in:
luoyu-intel 2024-08-19 06:12:39 +00:00
parent 4cffe910c3
commit 3d0a64f092
3 changed files with 43 additions and 9 deletions

View file

@ -28,6 +28,7 @@
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{ {
"name": "arm64-windows-msvc", "hidden": true, "name": "arm64-windows-msvc", "hidden": true,
@ -60,6 +61,8 @@
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] } { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
] ]
} }

View file

@ -2483,7 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
const sycl::half alpha_f16 = 1.0f; const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f; const sycl::half beta_f16 = 0.0f;
#if GGML_SYCL_DNNL #if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::mkl::transpose::trans, *stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@ -2495,7 +2495,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else #else
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>()); src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
#endif #endif
} }
else { else {
@ -2519,7 +2521,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
#if GGML_SYCL_DNNL #if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans, *stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,

View file

@ -39,11 +39,39 @@ public:
const void* a, dt at, const void* b, dt bt, void* c, dt ct) const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{ {
// Get the device associated with the queue // Get the device associated with the queue
sycl::device dev = q.get_device(); sycl::device dev = q.get_device();
// Get the context associated with the queue // Get the context associated with the queue
sycl::context ctx = q.get_context(); sycl::context ctx = q.get_context();
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
auto const eng = stream.get_engine();
dnnl::memory::dims a_dims = { m, k }; dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n }; dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n }; dnnl::memory::dims c_dims = { m, n };
@ -66,6 +94,7 @@ public:
matmul_prim.execute(stream, matmul_args); matmul_prim.execute(stream, matmul_args);
} }
}; };
#endif #endif
#endif // GGML_SYCL_GEMM_HPP #endif // GGML_SYCL_GEMM_HPP