binbcast: use void pointer to prevent intermediate type conversions

Akarshan Biswas 2025-01-31 18:30:29 +05:30
parent 2d72bd94b0
commit 957c11b2cf
3 changed files with 28 additions and 35 deletions
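
What the change amounts to: the generic bin_bcast path used to carry every buffer as float *, so callers cast F16/I32/I16 tensor data to float * only for ggml_sycl_op_bin_bcast to cast it again to the real element type. After this commit the callers pass the raw void * from ggml_tensor::data, and the single cast to a concrete element type happens where the tensor types are actually checked. A minimal, self-contained sketch of that dispatch pattern follows; the type and function names are toy stand-ins, not the SYCL backend's actual API.

#include <cstdint>
#include <cstdio>

enum class dtype { f32, i32 };

struct tensor {
    dtype type;
    void *data;   // raw, untyped storage, like ggml_tensor::data
    int   n;
};

// Typed kernel: instantiated only once the element type is known.
template <typename T>
static void add_kernel(const T *a, const T *b, T *dst, int n) {
    for (int i = 0; i < n; ++i) {
        dst[i] = a[i] + b[i];
    }
}

// Generic entry point: takes void pointers and casts exactly once, at the
// point where the tensor types are inspected. Callers never pretend the
// buffer holds floats.
static void op_add(const tensor &a, const tensor &b, tensor &dst) {
    if (a.type == dtype::f32 && dst.type == dtype::f32) {
        add_kernel(static_cast<const float *>(a.data),
                   static_cast<const float *>(b.data),
                   static_cast<float *>(dst.data), dst.n);
    } else if (a.type == dtype::i32 && dst.type == dtype::i32) {
        add_kernel(static_cast<const int32_t *>(a.data),
                   static_cast<const int32_t *>(b.data),
                   static_cast<int32_t *>(dst.data), dst.n);
    } else {
        std::fprintf(stderr, "unsupported type combination\n");
    }
}

int main() {
    int32_t a[3] = {1, 2, 3}, b[3] = {10, 20, 30}, out[3] = {0, 0, 0};
    tensor ta{dtype::i32, a, 3}, tb{dtype::i32, b, 3}, td{dtype::i32, out, 3};
    op_add(ta, tb, td);                                   // dispatches on dtype, not on pointer type
    std::printf("%d %d %d\n", out[0], out[1], out[2]);    // prints: 11 22 33
    return 0;
}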

View file

@@ -508,8 +508,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
 template<float (*bin_op)(const float, const float)>
 struct bin_bcast_sycl {
     template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
+    void operator()(const struct ggml_tensor *src0,
                     const struct ggml_tensor *src1, struct ggml_tensor *dst,
                     const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
                     queue_ptr stream) {
@@ -643,30 +642,29 @@ struct bin_bcast_sycl {
                 });
             }
         }
-        GGML_UNUSED(ctx);
     }
 };

 template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0,
                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd,
+                                   const void *src0_dd, const void *src1_dd,
+                                   void *dst_dd,
                                    const queue_ptr &main_stream) {
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+        op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd,
              (sycl::half *)dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+        op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd, (float *)dst_dd,
              main_stream);
     } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+        op()(src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
              main_stream);
     } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+        op()(src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
              main_stream);
     } else {
         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
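
These branches are also where the "intermediate type conversions" from the commit title used to occur: with float * parameters, an F16 or integer buffer arrived here already disguised as const float * and had to be cast back to its real type. The small stand-alone program below (illustrative only, not from the repository) contrasts the two plumbing styles; with void *, the pointer makes no claim about the element type until the dispatch site picks the one correct cast.

#include <cstdint>
#include <cstdio>

int main() {
    int32_t ints[2] = {1, 2};

    // Old-style plumbing: the int buffer travels as const float *, which says
    // nothing true about the memory, and must be cast back before use.
    const float   *as_float = reinterpret_cast<const float *>(ints);
    const int32_t *back     = reinterpret_cast<const int32_t *>(as_float);
    std::printf("%d %d\n", back[0], back[1]);     // 1 2 -- only because we cast back

    // void * plumbing: no element type is claimed until dispatch, where the
    // tensor type is inspected and a single cast is applied.
    const void    *raw   = ints;
    const int32_t *typed = static_cast<const int32_t *>(raw);
    std::printf("%d %d\n", typed[0], typed[1]);   // 1 2
    return 0;
}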

View file

@@ -756,43 +756,39 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
 inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx,
                              ggml_tensor *dst) {
-    // TODO: remove duplicate variables
-    const float * src0_dd = static_cast<float *>(dst->src[0]->data);
-    const float * src1_dd = static_cast<float *>(dst->src[1]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
+    const void * src0_dd = static_cast<void *>(dst->src[0]->data);
+    const void * src1_dd = static_cast<void *>(dst->src[1]->data);
+    void * dst_dd = static_cast<void *>(dst->data);

     const dpct::queue_ptr main_stream = ctx.stream();

-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
 }

 inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    // TODO: remove duplicate variables
-    const float * src0_dd = static_cast<float *>(dst->src[0]->data);
-    const float * src1_dd = static_cast<float *>(dst->src[1]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
+    const void * src0_dd = static_cast<void *>(dst->src[0]->data);
+    const void * src1_dd = static_cast<void *>(dst->src[1]->data);
+    void * dst_dd = static_cast<void *>(dst->data);

     const dpct::queue_ptr main_stream = ctx.stream();

-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
 }

 inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    // TODO: remove duplicate variables
-    const float * src0_dd = static_cast<float *>(dst->src[0]->data);
-    const float * src1_dd = static_cast<float *>(dst->src[1]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
+    const void * src0_dd = static_cast<void *>(dst->src[0]->data);
+    const void * src1_dd = static_cast<void *>(dst->src[1]->data);
+    void * dst_dd = static_cast<void *>(dst->data);

     const dpct::queue_ptr main_stream = ctx.stream();

-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
 }

 inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    // TODO: remove duplicate variables
-    const float * src0_dd = static_cast<float *>(dst->src[0]->data);
-    const float * src1_dd = static_cast<float *>(dst->src[1]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
+    const void * src0_dd = static_cast<void *>(dst->src[0]->data);
+    const void * src1_dd = static_cast<void *>(dst->src[1]->data);
+    void * dst_dd = static_cast<void *>(dst->data);

     const dpct::queue_ptr main_stream = ctx.stream();

-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
 }

View file

@@ -2534,12 +2534,11 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor
 static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    // TODO: remove duplicate variables
-    const float * src0_d = static_cast<float *>(dst->src[0]->data);
-    float * dst_d = static_cast<float *>(dst->data);
+    const void * src0_d = static_cast<void *>(dst->src[0]->data);
+    void * dst_d = static_cast<void *>(dst->data);

     dpct::queue_ptr main_stream = ctx.stream();

-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
 }