add sycl_f16
This commit is contained in:
parent
4cffe910c3
commit
3d0a64f092
3 changed files with 43 additions and 9 deletions
|
@ -28,6 +28,7 @@
|
||||||
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
|
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
|
||||||
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
|
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
|
||||||
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
|
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
|
||||||
|
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
|
||||||
|
|
||||||
{
|
{
|
||||||
"name": "arm64-windows-msvc", "hidden": true,
|
"name": "arm64-windows-msvc", "hidden": true,
|
||||||
|
@ -60,6 +61,8 @@
|
||||||
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
|
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
|
||||||
|
|
||||||
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
|
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
|
||||||
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
|
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
|
||||||
|
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
|
||||||
|
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -2483,7 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
|
|
||||||
const sycl::half alpha_f16 = 1.0f;
|
const sycl::half alpha_f16 = 1.0f;
|
||||||
const sycl::half beta_f16 = 0.0f;
|
const sycl::half beta_f16 = 0.0f;
|
||||||
#if GGML_SYCL_DNNL
|
#if !GGML_SYCL_DNNL
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||||
*stream, oneapi::mkl::transpose::trans,
|
*stream, oneapi::mkl::transpose::trans,
|
||||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
|
@ -2495,7 +2495,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
#else
|
#else
|
||||||
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||||
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
|
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
|
||||||
|
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
||||||
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -2519,7 +2521,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
|
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
#if GGML_SYCL_DNNL
|
#if !GGML_SYCL_DNNL
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
*stream, oneapi::mkl::transpose::trans,
|
*stream, oneapi::mkl::transpose::trans,
|
||||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
|
|
|
@ -39,11 +39,39 @@ public:
|
||||||
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||||
{
|
{
|
||||||
// Get the device associated with the queue
|
// Get the device associated with the queue
|
||||||
sycl::device dev = q.get_device();
|
sycl::device dev = q.get_device();
|
||||||
// Get the context associated with the queue
|
// Get the context associated with the queue
|
||||||
sycl::context ctx = q.get_context();
|
sycl::context ctx = q.get_context();
|
||||||
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
||||||
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
||||||
|
dnnl::memory::dims a_dims = { m, k };
|
||||||
|
dnnl::memory::dims b_dims = { k, n };
|
||||||
|
dnnl::memory::dims c_dims = { m, n };
|
||||||
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||||
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||||
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||||
|
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
||||||
|
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
||||||
|
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||||
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||||
|
|
||||||
|
// Create the primitive.
|
||||||
|
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||||
|
// Primitive arguments.
|
||||||
|
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||||
|
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||||
|
|
||||||
|
matmul_prim.execute(stream, matmul_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
|
||||||
|
bool b_trans, int m, int n, int k,
|
||||||
|
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||||
|
{
|
||||||
|
auto const eng = stream.get_engine();
|
||||||
dnnl::memory::dims a_dims = { m, k };
|
dnnl::memory::dims a_dims = { m, k };
|
||||||
dnnl::memory::dims b_dims = { k, n };
|
dnnl::memory::dims b_dims = { k, n };
|
||||||
dnnl::memory::dims c_dims = { m, n };
|
dnnl::memory::dims c_dims = { m, n };
|
||||||
|
@ -66,6 +94,7 @@ public:
|
||||||
matmul_prim.execute(stream, matmul_args);
|
matmul_prim.execute(stream, matmul_args);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // GGML_SYCL_GEMM_HPP
|
#endif // GGML_SYCL_GEMM_HPP
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue