diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 0d884f89a..8f92b76a6 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -2495,10 +2495,15 @@ inline void ggml_sycl_op_mul_mat_sycl( to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); #else auto dnnl_stream = ctx.stream_dnnl(stream); +#if 0 DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#else + DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), + src0_ptr, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt()); +#endif #endif } else { diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index d21104876..8ceac6533 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -282,26 +282,45 @@ struct ggml_backend_sycl_context { } #if GGML_SYCL_DNNL - dnnl::stream make_stream(sycl::queue& q) { + dnnl::engine make_engine(sycl::queue* q) { // Get the device associated with the queue - sycl::device dev = q.get_device(); + sycl::device dev = q->get_device(); // Get the context associated with the queue - sycl::context ctx = q.get_context(); + sycl::context ctx = q->get_context(); const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); - dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); - return stream; + return eng; } + std::unordered_map stream_map; + std::unordered_map engine_map; dnnl::stream stream_dnnl(int device, int _stream) { auto q = stream(device, _stream); return stream_dnnl(q); } + dnnl::engine engine_dnnl(sycl::queue* qptr) { + auto it = engine_map.find(qptr); + if (it == engine_map.end()) { + auto eng = make_engine(qptr); + engine_map[qptr] = eng; + return eng; + } + else + { + return it->second; + } + } dnnl::stream stream_dnnl(sycl::queue* qptr) { auto it = stream_map.find(qptr); if (it == stream_map.end()) { - stream_map[qptr] = make_stream(*qptr); + auto eng = engine_dnnl(qptr); + auto stream = dnnl::sycl_interop::make_stream(eng, *qptr); + stream_map[qptr] = stream; + return stream; + } + else + { + return it->second; } - return it->second; } dnnl::stream stream_dnnl() { return stream_dnnl(device, 0);