fix the perf issue of multi-device
Signed-off-by: Chen Xi <xi2.chen@intel.com>
This commit is contained in:
parent
6160a76efb
commit
e4b86a1295
2 changed files with 14 additions and 3 deletions
|
@ -267,7 +267,7 @@ struct ggml_backend_sycl_context {
|
|||
|
||||
queue_ptr stream(int device, int stream) {
|
||||
if (qptrs[device][stream] == nullptr) {
|
||||
qptrs[device][stream] = &(dpct::get_current_device().default_queue());
|
||||
qptrs[device][stream] = &(dpct::get_device(device).default_queue());
|
||||
}
|
||||
return qptrs[device][stream];
|
||||
}
|
||||
|
|
|
@ -883,6 +883,12 @@ namespace dpct
|
|||
else if (std::strstr(env, "opencl")) {
|
||||
filter = "opencl";
|
||||
}
|
||||
else if (std::strstr(env, "cuda")) {
|
||||
filter = "cuda";
|
||||
}
|
||||
else if (std::strstr(env, "hip")) {
|
||||
filter = "hip";
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("invalid device filter: " + std::string(env));
|
||||
}
|
||||
|
@ -2053,6 +2059,11 @@ namespace dpct
|
|||
return dev_mgr::instance().current_device();
|
||||
}
|
||||
|
||||
static inline device_ext &get_device(unsigned int id)
|
||||
{
|
||||
return dev_mgr::instance().get_device(id);
|
||||
}
|
||||
|
||||
static inline sycl::queue &get_in_order_queue()
|
||||
{
|
||||
return dev_mgr::instance().current_device().in_order_queue();
|
||||
|
@ -2490,6 +2501,7 @@ namespace dpct
|
|||
b, ldb, beta, c, ldc, batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_int32, library_data_t::real_int32):
|
||||
|
@ -2522,7 +2534,6 @@ namespace dpct
|
|||
batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_float):
|
||||
|
@ -2669,6 +2680,7 @@ namespace dpct
|
|||
beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_int8, library_data_t::real_int8,
|
||||
library_data_t::real_float, library_data_t::real_float):
|
||||
|
@ -2687,7 +2699,6 @@ namespace dpct
|
|||
beta, c, ldc, stride_c, batch_size);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case detail::get_type_combination_id(
|
||||
library_data_t::real_half, library_data_t::real_half,
|
||||
library_data_t::real_half, library_data_t::real_float):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue