SYCL: Refactor SYCL buffer checks in ggml_sycl_cpy_tensor_2d

This commit is contained in:
Akarshan Biswas 2024-12-19 08:34:53 +05:30
parent a20dde36ff
commit 6be041ae10
No known key found for this signature in database
GPG key ID: 52A578A14B32134D

View file

@ -2348,21 +2348,19 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
char * src_ptr; char * src_ptr;
if (ggml_backend_buffer_is_host(src->buffer)) { if (ggml_backend_buffer_is_host(src->buffer)) {
kind = dpct::host_to_device; kind = dpct::host_to_device;
GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n"); //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
src_ptr = (char *) src->data; src_ptr = (char *) src->data;
// GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr); // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
} else if (ggml_backend_buffer_is_sycl(src->buffer) || ggml_backend_buffer_is_sycl_split(src->buffer)) { } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
if (!ggml_backend_buffer_is_sycl_split(src->buffer)){
// If buffer is a SYCL buffer // If buffer is a SYCL buffer
GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__); //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
kind = dpct::device_to_device; kind = dpct::device_to_device;
src_ptr = (char *) src->data; src_ptr = (char *) src->data;
} } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
else {
/* /*
If buffer is a SYCL split buffer If buffer is a SYCL split buffer
*/ */
GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__); //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]); GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
kind = dpct::device_to_device; kind = dpct::device_to_device;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
@ -2371,7 +2369,6 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
id = get_current_device_id())); id = get_current_device_id()));
// GGML_SYCL_DEBUG("current device index %d\n", id); // GGML_SYCL_DEBUG("current device index %d\n", id);
src_ptr = (char *) extra->data_device[id]; src_ptr = (char *) extra->data_device[id];
}
} else { } else {
// GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n"); // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");