fix the perf issue of multi-device

Signed-off-by: Chen Xi <xi2.chen@intel.com>
This commit is contained in:
Chen Xi 2024-07-23 06:50:13 +00:00
parent 6160a76efb
commit e4b86a1295
2 changed files with 14 additions and 3 deletions

View file

@ -267,7 +267,7 @@ struct ggml_backend_sycl_context {
queue_ptr stream(int device, int stream) {
if (qptrs[device][stream] == nullptr) {
qptrs[device][stream] = &(dpct::get_current_device().default_queue());
qptrs[device][stream] = &(dpct::get_device(device).default_queue());
}
return qptrs[device][stream];
}

View file

@ -883,6 +883,12 @@ namespace dpct
else if (std::strstr(env, "opencl")) {
filter = "opencl";
}
else if (std::strstr(env, "cuda")) {
filter = "cuda";
}
else if (std::strstr(env, "hip")) {
filter = "hip";
}
else {
throw std::runtime_error("invalid device filter: " + std::string(env));
}
@ -2053,6 +2059,11 @@ namespace dpct
return dev_mgr::instance().current_device();
}
static inline device_ext &get_device(unsigned int id)
{
return dev_mgr::instance().get_device(id);
}
static inline sycl::queue &get_in_order_queue()
{
return dev_mgr::instance().current_device().in_order_queue();
@ -2490,6 +2501,7 @@ namespace dpct
b, ldb, beta, c, ldc, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_int32, library_data_t::real_int32):
@ -2522,7 +2534,6 @@ namespace dpct
batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):
@ -2669,6 +2680,7 @@ namespace dpct
beta, c, ldc, stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_int8, library_data_t::real_int8,
library_data_t::real_float, library_data_t::real_float):
@ -2687,7 +2699,6 @@ namespace dpct
beta, c, ldc, stride_c, batch_size);
break;
}
#endif
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_float):