diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 4bf875c9a..ae2778784 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -508,8 +508,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t template struct bin_bcast_sycl { template - void operator()(ggml_backend_sycl_context & ctx, - const struct ggml_tensor *src0, + void operator()(const struct ggml_tensor *src0, const struct ggml_tensor *src1, struct ggml_tensor *dst, const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, queue_ptr stream) { @@ -643,30 +642,29 @@ struct bin_bcast_sycl { }); } } - GGML_UNUSED(ctx); } }; template -inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, +inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, + const void *src0_dd, const void *src1_dd, + void *dst_dd, const queue_ptr &main_stream) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, + op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd, (sycl::half *)dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, + op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd, (float *)dst_dd, main_stream); } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, + op()(src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, main_stream); } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, + op()(src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, main_stream); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 6d68ea077..185bf11e7 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -756,43 +756,39 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - // TODO: remove duplicate variables - const float * src0_dd = static_cast(dst->src[0]->data); - const float * src1_dd = static_cast(dst->src[1]->data); - float * dst_dd = static_cast(dst->data); + const void * src0_dd = static_cast(dst->src[0]->data); + const void * src1_dd = static_cast(dst->src[1]->data); + void * dst_dd = static_cast(dst->data); const dpct::queue_ptr main_stream = ctx.stream(); - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - // TODO: remove duplicate variables - const float * src0_dd = static_cast(dst->src[0]->data); - const float * src1_dd = static_cast(dst->src[1]->data); - float * dst_dd = static_cast(dst->data); + const void * src0_dd = static_cast(dst->src[0]->data); + const void * src1_dd = static_cast(dst->src[1]->data); + void * dst_dd = static_cast(dst->data); const dpct::queue_ptr main_stream = ctx.stream(); - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - // TODO: remove duplicate variables - const float * src0_dd = static_cast(dst->src[0]->data); - const float * src1_dd = static_cast(dst->src[1]->data); - float * dst_dd = static_cast(dst->data); + const void * src0_dd = static_cast(dst->src[0]->data); + const void * src1_dd = static_cast(dst->src[1]->data); + void * dst_dd = static_cast(dst->data); const dpct::queue_ptr main_stream = ctx.stream(); - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - // TODO: remove duplicate variables - const float * src0_dd = static_cast(dst->src[0]->data); - const float * src1_dd = static_cast(dst->src[1]->data); - float * dst_dd = static_cast(dst->data); + const void * src0_dd = static_cast(dst->src[0]->data); + const void * src1_dd = static_cast(dst->src[1]->data); + void * dst_dd = static_cast(dst->data); const dpct::queue_ptr main_stream = ctx.stream(); - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 91c244579..0c49cb54f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2534,12 +2534,11 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - // TODO: remove duplicate variables - const float * src0_d = static_cast(dst->src[0]->data); - float * dst_d = static_cast(dst->data); + const void * src0_d = static_cast(dst->src[0]->data); + void * dst_d = static_cast(dst->data); dpct::queue_ptr main_stream = ctx.stream(); - ggml_sycl_op_bin_bcast>(ctx, dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream); + ggml_sycl_op_bin_bcast>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream); }