SYCL: Refactor SYCL buffer checks in ggml_sycl_cpy_tensor_2d

This commit is contained in:
Akarshan Biswas 2024-12-19 08:34:53 +05:30
parent a20dde36ff
commit 6be041ae10
No known key found for this signature in database
GPG key ID: 52A578A14B32134D

View file

@ -2348,21 +2348,19 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
char * src_ptr;
if (ggml_backend_buffer_is_host(src->buffer)) {
kind = dpct::host_to_device;
GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n");
//GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
src_ptr = (char *) src->data;
// GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
} else if (ggml_backend_buffer_is_sycl(src->buffer) || ggml_backend_buffer_is_sycl_split(src->buffer)) {
if (!ggml_backend_buffer_is_sycl_split(src->buffer)){
// If buffer is a SYCL buffer
GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
kind = dpct::device_to_device;
src_ptr = (char *) src->data;
}
else {
/*
If buffer is a SYCL split buffer
*/
GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
} else if (ggml_backend_buffer_is_sycl(src->buffer)) {
// If buffer is a SYCL buffer
//GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
kind = dpct::device_to_device;
src_ptr = (char *) src->data;
} else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
/*
If buffer is a SYCL split buffer
*/
//GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
kind = dpct::device_to_device;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
@ -2371,7 +2369,6 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
id = get_current_device_id()));
// GGML_SYCL_DEBUG("current device index %d\n", id);
src_ptr = (char *) extra->data_device[id];
}
} else {
// GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
GGML_ABORT("fatal error");