add engine map

This commit is contained in:
luoyu-intel 2024-08-19 07:29:43 +00:00
parent 4dc55156ee
commit c751e65d81
2 changed files with 31 additions and 7 deletions

View file

@ -2495,10 +2495,15 @@ inline void ggml_sycl_op_mul_mat_sycl(
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else #else
auto dnnl_stream = ctx.stream_dnnl(stream); auto dnnl_stream = ctx.stream_dnnl(stream);
#if 0
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>()); src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
#else
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
#endif
#endif #endif
} }
else { else {

View file

@ -282,26 +282,45 @@ struct ggml_backend_sycl_context {
} }
#if GGML_SYCL_DNNL #if GGML_SYCL_DNNL
dnnl::stream make_stream(sycl::queue& q) { dnnl::engine make_engine(sycl::queue* q) {
// Get the device associated with the queue // Get the device associated with the queue
sycl::device dev = q.get_device(); sycl::device dev = q->get_device();
// Get the context associated with the queue // Get the context associated with the queue
sycl::context ctx = q.get_context(); sycl::context ctx = q->get_context();
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); return eng;
return stream;
} }
std::unordered_map<sycl::queue*, dnnl::stream> stream_map; std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
dnnl::stream stream_dnnl(int device, int _stream) { dnnl::stream stream_dnnl(int device, int _stream) {
auto q = stream(device, _stream); auto q = stream(device, _stream);
return stream_dnnl(q); return stream_dnnl(q);
} }
dnnl::engine engine_dnnl(sycl::queue* qptr) {
auto it = engine_map.find(qptr);
if (it == engine_map.end()) {
auto eng = make_engine(qptr);
engine_map[qptr] = eng;
return eng;
}
else
{
return it->second;
}
}
dnnl::stream stream_dnnl(sycl::queue* qptr) { dnnl::stream stream_dnnl(sycl::queue* qptr) {
auto it = stream_map.find(qptr); auto it = stream_map.find(qptr);
if (it == stream_map.end()) { if (it == stream_map.end()) {
stream_map[qptr] = make_stream(*qptr); auto eng = engine_dnnl(qptr);
auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
stream_map[qptr] = stream;
return stream;
}
else
{
return it->second;
} }
return it->second;
} }
dnnl::stream stream_dnnl() { dnnl::stream stream_dnnl() {
return stream_dnnl(device, 0); return stream_dnnl(device, 0);