remove duplicate buft initialization

Author: Meng, Hengyu
Date:   2024-06-05 16:57:17 +08:00
Parent: abe11feab6
Commit: 9c5476ead4


@@ -12606,6 +12606,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
 };
 
 ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
     if (device>=ggml_sycl_info().device_count or device<0) {
@@ -13000,6 +13003,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
 };
 
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
     ggml_check_sycl();
     // FIXME: this is not thread safe
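
The two hunks above add the same guard to both buffer-type getters: a function-local static std::mutex, locked for the remainder of the call by std::lock_guard, so concurrent callers can no longer race on the lazily initialized buffer-type table (the FIXME in the second hunk marks exactly that hazard). A minimal self-contained sketch of the pattern; the stub type, table size, and names are illustrative, not the actual ggml-sycl internals:

#include <array>
#include <mutex>

// Illustrative stand-in for a lazily built per-device table; not the real
// ggml_backend_sycl_buffer_type internals.
struct buft_stub {
    bool initialized = false;
    int  device      = -1;
};

static buft_stub * buffer_type_stub(int device) {
    static std::mutex mutex;
    static std::array<buft_stub, 16> types;   // assumed fixed device cap
    std::lock_guard<std::mutex> lock(mutex);  // serialize the init check below
    buft_stub & t = types.at(device);
    if (!t.initialized) {                     // first caller per device wins
        t.device      = device;
        t.initialized = true;
    }
    return &t;
}

A function-local static mutex keeps the lock scoped to the one structure it protects, and std::lock_guard releases it on every return path, including early exits such as the device-range check in the first hunk.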
@@ -13106,7 +13112,7 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    return ggml_backend_sycl_buffer_type(sycl_ctx);
+    return ggml_backend_sycl_buffer_type(sycl_ctx->device);
 }
 
 GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
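
This hunk, and every remaining one, switches the call sites from passing the context pointer to passing sycl_ctx->device, so all lookups go through the single getter keyed by device id; presumably the second, context-keyed path is the duplicate buffer-type initialization the commit title removes. A hedged sketch of the call-site shape, with stand-in names in place of the real ggml-sycl functions:

// Stand-ins for the real ggml-sycl types; only the device field matters here.
struct sycl_context_stub { int device; };

// Single getter keyed by device id (placeholder body).
static int buffer_type_stub_id(int device) {
    return device;
}

static int default_buffer_type_stub(sycl_context_stub * ctx) {
    // after this commit: look up by device id, not by context pointer
    return buffer_type_stub_id(ctx->device);
}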
@@ -13114,8 +13120,9 @@ GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
                                                const void *data, size_t offset,
                                                size_t size) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
         (char *)tensor->data + offset, data, size).wait()));
@@ -13131,8 +13138,9 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
                                                void *data, size_t offset,
                                                size_t size) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
         data, (const char *)tensor->data + offset, size).wait()));
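
Both async-copy hunks replace the direct tensor->buffer assertion with a lookup through view_src, dropping the old tensor->backend check: a view tensor aliases its parent's storage, so the buffer-type validation has to run against the parent's buffer. A simplified sketch of that resolution, with stand-in types in place of ggml_tensor and ggml_backend_buffer_t:

// Simplified stand-ins; only the fields the ternary in the diff touches
// are modeled here.
struct buffer_stub { int device; };

struct tensor_stub {
    buffer_stub * buffer   = nullptr;  // storage this tensor points into
    tensor_stub * view_src = nullptr;  // parent tensor when this is a view
};

// Same resolution as the added line: views delegate to the parent's buffer.
static buffer_stub * owning_buffer(const tensor_stub * t) {
    return t->view_src ? t->view_src->buffer : t->buffer;
}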
@@ -13147,7 +13155,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
                                                const ggml_tensor *src,
                                                ggml_tensor *dst) try {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) {
+    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
         /*
         DPCT1009:215: SYCL uses exceptions to report errors and does not use the
         error codes. The original code was commented out and a warning string
@@ -13191,10 +13199,10 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
             continue;
         }
 #ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
+        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
+                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
             }
         }
 #endif