From e4b86a12952bc5acec54a4d7a8194191f9384ef9 Mon Sep 17 00:00:00 2001 From: Chen Xi Date: Tue, 23 Jul 2024 06:50:13 +0000 Subject: [PATCH] fix the perf issue of multi-device Signed-off-by: Chen Xi --- ggml/src/ggml-sycl/common.hpp | 2 +- ggml/src/ggml-sycl/dpct/helper.hpp | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 68d41411b..397bd98dd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -267,7 +267,7 @@ struct ggml_backend_sycl_context { queue_ptr stream(int device, int stream) { if (qptrs[device][stream] == nullptr) { - qptrs[device][stream] = &(dpct::get_current_device().default_queue()); + qptrs[device][stream] = &(dpct::get_device(device).default_queue()); } return qptrs[device][stream]; } diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 0f18cff32..a313ca6e5 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -883,6 +883,12 @@ namespace dpct else if (std::strstr(env, "opencl")) { filter = "opencl"; } + else if (std::strstr(env, "cuda")) { + filter = "cuda"; + } + else if (std::strstr(env, "hip")) { + filter = "hip"; + } else { throw std::runtime_error("invalid device filter: " + std::string(env)); } @@ -2053,6 +2059,11 @@ namespace dpct return dev_mgr::instance().current_device(); } + static inline device_ext &get_device(unsigned int id) + { + return dev_mgr::instance().get_device(id); + } + static inline sycl::queue &get_in_order_queue() { return dev_mgr::instance().current_device().in_order_queue(); @@ -2490,6 +2501,7 @@ namespace dpct b, ldb, beta, c, ldc, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): @@ -2522,7 +2534,6 @@ namespace dpct batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): @@ -2669,6 +2680,7 @@ namespace dpct beta, c, ldc, stride_c, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_float, library_data_t::real_float): @@ -2687,7 +2699,6 @@ namespace dpct beta, c, ldc, stride_c, batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float):