fix multi-gpu issue on sycl
Signed-off-by: Chen Xi <xi2chen@intel.com>
This commit is contained in:
parent
090fca7a07
commit
cd296feac3
2 changed files with 113 additions and 43 deletions
|
@ -255,7 +255,7 @@ namespace dpct
|
||||||
void set_pitch(size_t pitch) { _pitch = pitch; }
|
void set_pitch(size_t pitch) { _pitch = pitch; }
|
||||||
|
|
||||||
size_t get_x() { return _x; }
|
size_t get_x() { return _x; }
|
||||||
void set_x(size_t x) { _x = x; }
|
void set_x(size_t x) { _x = x; };
|
||||||
|
|
||||||
size_t get_y() { return _y; }
|
size_t get_y() { return _y; }
|
||||||
void set_y(size_t y) { _y = y; }
|
void set_y(size_t y) { _y = y; }
|
||||||
|
@ -588,7 +588,7 @@ namespace dpct
|
||||||
out = prop;
|
out = prop;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// dpct device extension
|
/// dpct device extension
|
||||||
class device_ext : public sycl::device {
|
class device_ext : public sycl::device {
|
||||||
typedef std::mutex mutex_type;
|
typedef std::mutex mutex_type;
|
||||||
|
|
||||||
|
@ -687,119 +687,128 @@ namespace dpct
|
||||||
init_queues();
|
init_queues();
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue &in_order_queue() { return _q_in_order; }
|
sycl::queue &in_order_queue() { return *_q_in_order; }
|
||||||
|
|
||||||
sycl::queue &out_of_order_queue() { return _q_out_of_order; }
|
sycl::queue &out_of_order_queue() { return *_q_out_of_order; }
|
||||||
|
|
||||||
sycl::queue &default_queue() { return in_order_queue(); }
|
sycl::queue &default_queue() { return in_order_queue(); }
|
||||||
|
|
||||||
void queues_wait_and_throw() {
|
void queues_wait_and_throw() {
|
||||||
std::unique_lock<mutex_type> lock(m_mutex);
|
std::unique_lock<mutex_type> lock(m_mutex);
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
for (auto &q : _queues) {
|
for (const auto &q : _queues) {
|
||||||
q.wait_and_throw();
|
q->wait_and_throw();
|
||||||
}
|
}
|
||||||
// Guard the destruct of current_queues to make sure the ref count is
|
// Guard the destruct of current_queues to make sure the ref count is
|
||||||
// safe.
|
// safe.
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue create_queue(bool enable_exception_handler = false) {
|
sycl::queue *create_queue(bool enable_exception_handler = false) {
|
||||||
return create_in_order_queue(enable_exception_handler);
|
return create_in_order_queue(enable_exception_handler);
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue create_queue(sycl::device device,
|
sycl::queue *create_queue(sycl::device device,
|
||||||
bool enable_exception_handler = false) {
|
bool enable_exception_handler = false) {
|
||||||
return create_in_order_queue(device, enable_exception_handler);
|
return create_in_order_queue(device, enable_exception_handler);
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
|
sycl::queue *create_in_order_queue(bool enable_exception_handler = false) {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
return create_queue_impl(enable_exception_handler,
|
return create_queue_impl(enable_exception_handler,
|
||||||
sycl::property::queue::in_order());
|
sycl::property::queue::in_order());
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue create_in_order_queue(sycl::device device,
|
sycl::queue *create_in_order_queue(sycl::device device,
|
||||||
bool enable_exception_handler = false) {
|
bool enable_exception_handler = false) {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
return create_queue_impl(device, enable_exception_handler,
|
return create_queue_impl(device, enable_exception_handler,
|
||||||
sycl::property::queue::in_order());
|
sycl::property::queue::in_order());
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl::queue create_out_of_order_queue(
|
sycl::queue *create_out_of_order_queue(
|
||||||
bool enable_exception_handler = false) {
|
bool enable_exception_handler = false) {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
return create_queue_impl(enable_exception_handler);
|
return create_queue_impl(enable_exception_handler);
|
||||||
}
|
}
|
||||||
|
|
||||||
void destroy_queue(sycl::queue queue) {
|
void destroy_queue(sycl::queue *&queue) {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
_queues.clear();
|
_queues.erase(std::remove_if(_queues.begin(), _queues.end(),
|
||||||
|
[=](const std::shared_ptr<sycl::queue> &q) -> bool
|
||||||
|
{
|
||||||
|
return q.get() == queue;
|
||||||
|
}),
|
||||||
|
_queues.end());
|
||||||
|
queue = nullptr;
|
||||||
}
|
}
|
||||||
void set_saved_queue(sycl::queue q) {
|
void set_saved_queue(sycl::queue *q) {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
_saved_queue = q;
|
_saved_queue = q;
|
||||||
}
|
}
|
||||||
sycl::queue get_saved_queue() const {
|
sycl::queue *get_saved_queue() const {
|
||||||
std::lock_guard<mutex_type> lock(m_mutex);
|
std::lock_guard<mutex_type> lock(m_mutex);
|
||||||
return _saved_queue;
|
return _saved_queue;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void clear_queues() { _queues.clear(); }
|
void clear_queues() {
|
||||||
|
_queues.clear();
|
||||||
|
_q_in_order = _q_out_of_order = _saved_queue = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
void init_queues() {
|
void init_queues() {
|
||||||
_q_in_order =
|
_q_in_order =
|
||||||
create_queue_impl(true, sycl::property::queue::in_order());
|
create_queue_impl(true, sycl::property::queue::in_order());
|
||||||
_q_out_of_order = create_queue_impl(true);
|
_q_out_of_order = create_queue_impl(true);
|
||||||
_saved_queue = default_queue();
|
_saved_queue = &default_queue();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Caller should acquire resource \p m_mutex before calling this
|
/// Caller should acquire resource \p m_mutex before calling this
|
||||||
/// function.
|
/// function.
|
||||||
template <class... Properties>
|
template <class... Properties>
|
||||||
sycl::queue create_queue_impl(bool enable_exception_handler,
|
sycl::queue *create_queue_impl(bool enable_exception_handler,
|
||||||
Properties... properties) {
|
Properties... properties) {
|
||||||
sycl::async_handler eh = {};
|
sycl::async_handler eh = {};
|
||||||
if (enable_exception_handler) {
|
if (enable_exception_handler) {
|
||||||
eh = exception_handler;
|
eh = exception_handler;
|
||||||
}
|
}
|
||||||
auto q = sycl::queue(*this, eh,
|
_queues.push_back(std::make_shared<sycl::queue>(
|
||||||
sycl::property_list(
|
*this, eh,
|
||||||
|
sycl::property_list(
|
||||||
#ifdef DPCT_PROFILING_ENABLED
|
#ifdef DPCT_PROFILING_ENABLED
|
||||||
sycl::property::queue::enable_profiling(),
|
sycl::property::queue::enable_profiling(),
|
||||||
#endif
|
#endif
|
||||||
properties...));
|
properties...)));
|
||||||
_queues.push_back(q);
|
|
||||||
|
|
||||||
return _queues.back();
|
return _queues.back().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class... Properties>
|
template <class... Properties>
|
||||||
sycl::queue create_queue_impl(sycl::device device,
|
sycl::queue *create_queue_impl(sycl::device device,
|
||||||
bool enable_exception_handler,
|
bool enable_exception_handler,
|
||||||
Properties... properties) {
|
Properties... properties) {
|
||||||
sycl::async_handler eh = {};
|
sycl::async_handler eh = {};
|
||||||
if (enable_exception_handler) {
|
if (enable_exception_handler) {
|
||||||
eh = exception_handler;
|
eh = exception_handler;
|
||||||
}
|
}
|
||||||
_queues.push_back(
|
_queues.push_back(std::make_shared<sycl::queue>(
|
||||||
sycl::queue(device, eh,
|
device, eh,
|
||||||
sycl::property_list(
|
sycl::property_list(
|
||||||
#ifdef DPCT_PROFILING_ENABLED
|
#ifdef DPCT_PROFILING_ENABLED
|
||||||
sycl::property::queue::enable_profiling(),
|
sycl::property::queue::enable_profiling(),
|
||||||
#endif
|
#endif
|
||||||
properties...)));
|
properties...)));
|
||||||
|
|
||||||
return _queues.back();
|
return _queues.back().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_version(int &major, int &minor) const {
|
void get_version(int &major, int &minor) const {
|
||||||
detail::get_version(*this, major, minor);
|
detail::get_version(*this, major, minor);
|
||||||
}
|
}
|
||||||
sycl::queue _q_in_order, _q_out_of_order;
|
sycl::queue *_q_in_order, *_q_out_of_order;
|
||||||
sycl::queue _saved_queue;
|
sycl::queue *_saved_queue;
|
||||||
std::vector<sycl::queue> _queues;
|
std::vector<std::shared_ptr<sycl::queue>> _queues;
|
||||||
mutable mutex_type m_mutex;
|
mutable mutex_type m_mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -855,15 +864,69 @@ namespace dpct
|
||||||
unsigned int get_device_id(const sycl::device &dev)
|
unsigned int get_device_id(const sycl::device &dev)
|
||||||
{
|
{
|
||||||
unsigned int id = 0;
|
unsigned int id = 0;
|
||||||
for (auto dev_item : _devs)
|
for (auto &dev_item : _devs)
|
||||||
{
|
{
|
||||||
if (*dev_item == dev)
|
if (*dev_item == dev)
|
||||||
{
|
{
|
||||||
break;
|
return id;
|
||||||
}
|
}
|
||||||
id++;
|
id++;
|
||||||
}
|
}
|
||||||
return id;
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string get_preferred_gpu_platform_name() {
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
std::string filter = "level-zero";
|
||||||
|
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
|
||||||
|
if (env) {
|
||||||
|
if (std::strstr(env, "level_zero")) {
|
||||||
|
filter = "level-zero";
|
||||||
|
}
|
||||||
|
else if (std::strstr(env, "opencl")) {
|
||||||
|
filter = "opencl";
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw std::runtime_error("invalid device filter: " + std::string(env));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto plaform_list = sycl::platform::get_platforms();
|
||||||
|
|
||||||
|
for (const auto& platform : plaform_list) {
|
||||||
|
auto devices = platform.get_devices();
|
||||||
|
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
|
||||||
|
return d.is_gpu();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (gpu_dev == devices.end()) {
|
||||||
|
// cout << "platform [" << platform_name
|
||||||
|
// << "] does not contain GPU devices, skipping\n";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto platform_name = platform.get_info<sycl::info::platform::name>();
|
||||||
|
std::string platform_name_low_case;
|
||||||
|
platform_name_low_case.resize(platform_name.size());
|
||||||
|
|
||||||
|
std::transform(
|
||||||
|
platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
|
||||||
|
|
||||||
|
if (platform_name_low_case.find(filter) == std::string::npos) {
|
||||||
|
// cout << "platform [" << platform_name
|
||||||
|
// << "] does not match with requested "
|
||||||
|
// << filter << ", skipping\n";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
result = platform_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result.empty())
|
||||||
|
throw std::runtime_error("can not find preferred GPU platform");
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class DeviceSelector>
|
template <class DeviceSelector>
|
||||||
|
@ -915,6 +978,7 @@ namespace dpct
|
||||||
static bool compare_backend(std::string &backend1, std::string &backend2) {
|
static bool compare_backend(std::string &backend1, std::string &backend2) {
|
||||||
return convert_backend_index(backend1) < convert_backend_index(backend2);
|
return convert_backend_index(backend1) < convert_backend_index(backend2);
|
||||||
}
|
}
|
||||||
|
|
||||||
dev_mgr()
|
dev_mgr()
|
||||||
{
|
{
|
||||||
sycl::device default_device =
|
sycl::device default_device =
|
||||||
|
@ -928,12 +992,17 @@ namespace dpct
|
||||||
|
|
||||||
auto Platforms = sycl::platform::get_platforms();
|
auto Platforms = sycl::platform::get_platforms();
|
||||||
// Keep track of the number of devices per backend
|
// Keep track of the number of devices per backend
|
||||||
std::map<sycl::backend, size_t> DeviceNums;
|
std::map<sycl::backend, size_t> DeviceNums;
|
||||||
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
std::map<std::string, std::vector<sycl::device>> backend_devices;
|
||||||
|
auto preferred_platform_name = get_preferred_gpu_platform_name();
|
||||||
|
|
||||||
while (!Platforms.empty()) {
|
while (!Platforms.empty()) {
|
||||||
auto Platform = Platforms.back();
|
auto Platform = Platforms.back();
|
||||||
Platforms.pop_back();
|
Platforms.pop_back();
|
||||||
|
auto platform_name = Platform.get_info<sycl::info::platform::name>();
|
||||||
|
if (platform_name.compare(preferred_platform_name) != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto devices = Platform.get_devices();
|
auto devices = Platform.get_devices();
|
||||||
std::string backend_type = get_device_backend_and_type(devices[0]);
|
std::string backend_type = get_device_backend_and_type(devices[0]);
|
||||||
for (const auto &device : devices) {
|
for (const auto &device : devices) {
|
||||||
|
@ -945,6 +1014,7 @@ namespace dpct
|
||||||
for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
|
for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
|
||||||
keys.push_back(it->first);
|
keys.push_back(it->first);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::sort(keys.begin(), keys.end(), compare_backend);
|
std::sort(keys.begin(), keys.end(), compare_backend);
|
||||||
|
|
||||||
for (auto &key : keys) {
|
for (auto &key : keys) {
|
||||||
|
@ -967,7 +1037,9 @@ namespace dpct
|
||||||
_cpu_device = _devs.size() - 1;
|
_cpu_device = _devs.size() - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
}
|
||||||
|
|
||||||
void check_id(unsigned int id) const
|
void check_id(unsigned int id) const
|
||||||
{
|
{
|
||||||
if (id >= _devs.size())
|
if (id >= _devs.size())
|
||||||
|
@ -1056,7 +1128,7 @@ namespace dpct
|
||||||
#error "Only support Windows and Linux."
|
#error "Only support Windows and Linux."
|
||||||
#endif
|
#endif
|
||||||
next_free = mapped_address_space;
|
next_free = mapped_address_space;
|
||||||
}
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using buffer_id_t = int;
|
using buffer_id_t = int;
|
||||||
|
@ -1077,7 +1149,7 @@ namespace dpct
|
||||||
#else
|
#else
|
||||||
#error "Only support Windows and Linux."
|
#error "Only support Windows and Linux."
|
||||||
#endif
|
#endif
|
||||||
}
|
};
|
||||||
|
|
||||||
mem_mgr(const mem_mgr &) = delete;
|
mem_mgr(const mem_mgr &) = delete;
|
||||||
mem_mgr &operator=(const mem_mgr &) = delete;
|
mem_mgr &operator=(const mem_mgr &) = delete;
|
||||||
|
@ -2426,7 +2498,6 @@ namespace dpct
|
||||||
b, ldb, beta, c, ldc, batch_size);
|
b, ldb, beta, c, ldc, batch_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_int8, library_data_t::real_int8,
|
library_data_t::real_int8, library_data_t::real_int8,
|
||||||
library_data_t::real_int32, library_data_t::real_int32):
|
library_data_t::real_int32, library_data_t::real_int32):
|
||||||
|
@ -2459,6 +2530,7 @@ namespace dpct
|
||||||
batch_size);
|
batch_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_half, library_data_t::real_half,
|
library_data_t::real_half, library_data_t::real_half,
|
||||||
library_data_t::real_half, library_data_t::real_float):
|
library_data_t::real_half, library_data_t::real_float):
|
||||||
|
@ -2595,7 +2667,6 @@ namespace dpct
|
||||||
stride_c, batch_size);
|
stride_c, batch_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_int8, library_data_t::real_int8,
|
library_data_t::real_int8, library_data_t::real_int8,
|
||||||
library_data_t::real_int32, library_data_t::real_int32):
|
library_data_t::real_int32, library_data_t::real_int32):
|
||||||
|
@ -2624,6 +2695,7 @@ namespace dpct
|
||||||
beta, c, ldc, stride_c, batch_size);
|
beta, c, ldc, stride_c, batch_size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
library_data_t::real_half, library_data_t::real_half,
|
library_data_t::real_half, library_data_t::real_half,
|
||||||
library_data_t::real_half, library_data_t::real_float):
|
library_data_t::real_half, library_data_t::real_float):
|
||||||
|
|
|
@ -19154,9 +19154,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
|
for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
|
||||||
ggml_backend_t backend = ggml_backend_sycl_init(i);
|
ggml_backend_t backend = ggml_backend_sycl_init(i);
|
||||||
if (backend == nullptr) {
|
if (backend == nullptr) {
|
||||||
int id_list[GGML_SYCL_MAX_DEVICES];
|
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
|
||||||
ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, id_list[i], i);
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue