binbcast: use void pointer to prevent intermediate type conversions
This commit is contained in:
parent
2d72bd94b0
commit
957c11b2cf
3 changed files with 28 additions and 35 deletions
|
@ -508,8 +508,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
|
|||
template<float (*bin_op)(const float, const float)>
|
||||
struct bin_bcast_sycl {
|
||||
template <typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(ggml_backend_sycl_context & ctx,
|
||||
const struct ggml_tensor *src0,
|
||||
void operator()(const struct ggml_tensor *src0,
|
||||
const struct ggml_tensor *src1, struct ggml_tensor *dst,
|
||||
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
|
||||
queue_ptr stream) {
|
||||
|
@ -643,30 +642,29 @@ struct bin_bcast_sycl {
|
|||
});
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class op>
|
||||
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const void *src0_dd, const void *src1_dd,
|
||||
void *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
|
||||
op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd,
|
||||
(sycl::half *)dst_dd, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
|
||||
op()(src0, src1, dst, (const sycl::half *)src0_dd, (const float *)src1_dd, (float *)dst_dd,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
||||
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
||||
op()(src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
||||
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
||||
op()(src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
||||
main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
||||
|
|
|
@ -756,43 +756,39 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
|
|||
|
||||
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx,
|
||||
ggml_tensor *dst) {
|
||||
// TODO: remove duplicate variables
|
||||
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
|
||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||
void * dst_dd = static_cast<void *>(dst->data);
|
||||
const dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
// TODO: remove duplicate variables
|
||||
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
|
||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||
void * dst_dd = static_cast<void *>(dst->data);
|
||||
const dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
// TODO: remove duplicate variables
|
||||
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
|
||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||
void * dst_dd = static_cast<void *>(dst->data);
|
||||
const dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
// TODO: remove duplicate variables
|
||||
const float * src0_dd = static_cast<float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<float *>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
const void * src0_dd = static_cast<void *>(dst->src[0]->data);
|
||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||
void * dst_dd = static_cast<void *>(dst->data);
|
||||
const dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -2534,12 +2534,11 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor
|
|||
|
||||
|
||||
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
// TODO: remove duplicate variables
|
||||
const float * src0_d = static_cast<float *>(dst->src[0]->data);
|
||||
float * dst_d = static_cast<float *>(dst->data);
|
||||
const void * src0_d = static_cast<void *>(dst->src[0]->data);
|
||||
void * dst_d = static_cast<void *>(dst->data);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue