sycl: add some ops

This commit is contained in:
Zhiyuan Li 2024-11-03 04:55:29 +11:00
parent 2fc42b6a82
commit bee1cec7d2
5 changed files with 611 additions and 2 deletions

View file

@ -1209,6 +1209,10 @@ static __dpct_inline__ float op_add(const float a, const float b) {
return a + b;
}
static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}
@ -1373,6 +1377,50 @@ static void relu_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmax((float)(x[i]), (float)0);
}
static void sigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
}
static void sqrt_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = sycl::sqrt(x[i]);
}
static void sin_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = sycl::sin(x[i]);
}
static void cos_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = sycl::cos(x[i]);
}
static void hardsigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@ -1395,6 +1443,55 @@ static void hardswish_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
}
static void exp_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = sycl::exp(x[i]);
}
static void log_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
float xi = x[i];
if (xi <= 0) {
dst[i] = -INFINITY;
} else {
dst[i] = sycl::log(xi);
}
}
static void neg_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = -x[i];
}
static void step_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = x[i] > 0.0f;
}
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@ -2388,6 +2485,102 @@ static void hardswish_f32_sycl(const float *x, float *dst, const int k,
});
}
static void exp_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
exp_f32(x, dst, k, item_ct1);
});
}
static void log_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
log_f32(x, dst, k, item_ct1);
});
}
static void neg_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
neg_f32(x, dst, k, item_ct1);
});
}
static void step_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
step_f32(x, dst, k, item_ct1);
});
}
static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sigmoid_f32(x, dst, k, item_ct1);
});
}
static void sqrt_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sqrt_f32(x, dst, k, item_ct1);
});
}
static void sin_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sin_f32(x, dst, k, item_ct1);
});
}
static void cos_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cos_f32(x, dst, k, item_ct1);
});
}
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
const float negative_slope,
queue_ptr stream) {
@ -2816,6 +3009,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
}
}
static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float);
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<float, 1> shared_data(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1);
float max_val = -INFINITY;
int max_idx = -1;
for (int col = tid; col < ncols; col += 256) {
float val = x[row * ncols + col];
if (val > max_val) {
max_val = val;
max_idx = col;
}
}
shared_data[tid] = max_val;
shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int stride = 256/2; stride > 0; stride >>= 1) {
if (tid < stride) {
float val1 = shared_data[tid];
float val2 = shared_data[tid + stride];
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
if (tid == 0) {
dst[row] = shared_indices[0];
}
});
});
}
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
const int ncols_x, const int nrows_x,
const int rows_per_channel, const int n_past,
@ -2994,6 +3239,14 @@ inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
@ -3105,7 +3358,7 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor
(void) src1_dd;
}
static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
@ -3121,7 +3374,7 @@ static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml
(void) src1_dd;
}
static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
@ -3136,6 +3389,126 @@ static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_t
(void) src1_dd;
}
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3379,6 +3752,23 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
(void) src1_dd;
}
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne = ggml_nelements(src0);
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3419,6 +3809,25 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten
(void) src1_dd;
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
@ -3914,6 +4323,30 @@ static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * s
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
@ -3962,6 +4395,12 @@ static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
@ -3974,6 +4413,31 @@ static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
@ -4632,6 +5096,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
}
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
}
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
@ -4642,6 +5111,11 @@ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
}
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
}
static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@ -4673,6 +5147,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
ggml_sycl_func_t func;
switch (tensor->op) {
case GGML_OP_ARGMAX:
func = ggml_sycl_argmax;
break;
case GGML_OP_CONV_TRANSPOSE_1D:
func = ggml_sycl_op_conv_transpose_1d;
break;
@ -4686,19 +5163,32 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
func = ggml_sycl_dup;
break;
case GGML_OP_ADD:
case GGML_OP_ADD1: // TODO: more efficient implementation
func = ggml_sycl_add;
break;
case GGML_OP_SUB:
func = ggml_sycl_sub;
break;
case GGML_OP_ACC:
func = ggml_sycl_acc;
break;
case GGML_OP_MUL:
func = ggml_sycl_mul;
break;
case GGML_OP_LOG:
func = ggml_sycl_log;
break;
case GGML_OP_DIV:
func = ggml_sycl_div;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_NEG:
func = ggml_sycl_neg;
break;
case GGML_UNARY_OP_STEP:
func = ggml_sycl_step;
break;
case GGML_UNARY_OP_GELU:
func = ggml_sycl_gelu;
break;
@ -4714,12 +5204,18 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_UNARY_OP_RELU:
func = ggml_sycl_relu;
break;
case GGML_UNARY_OP_SIGMOID:
func = ggml_sycl_sigmoid;
break;
case GGML_UNARY_OP_HARDSIGMOID:
func = ggml_sycl_hardsigmoid;
break;
case GGML_UNARY_OP_HARDSWISH:
func = ggml_sycl_hardswish;
break;
case GGML_UNARY_OP_EXP:
func = ggml_sycl_exp;
break;
default:
return false;
}
@ -4757,12 +5253,24 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
}
func = ggml_sycl_mul_mat_id;
break;
case GGML_OP_OUT_PROD:
func = ggml_sycl_op_out_prod;
break;
case GGML_OP_SCALE:
func = ggml_sycl_scale;
break;
case GGML_OP_SQR:
func = ggml_sycl_sqr;
break;
case GGML_OP_SQRT:
func = ggml_sycl_sqrt;
break;
case GGML_OP_SIN:
func = ggml_sycl_sin;
break;
case GGML_OP_COS:
func = ggml_sycl_cos;
break;
case GGML_OP_CLAMP:
func = ggml_sycl_clamp;
break;
@ -4794,6 +5302,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_POOL_2D:
func = ggml_sycl_pool2d;
break;
case GGML_OP_SUM:
func = ggml_sycl_sum;
break;
case GGML_OP_SUM_ROWS:
func = ggml_sycl_sum_rows;
break;
@ -5128,13 +5639,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -5171,6 +5686,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
}
return true;
} break;
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
@ -5220,6 +5737,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
} break;
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_REPEAT:
@ -5228,11 +5746,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_LOG:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return true;
case GGML_OP_CONT:
@ -5246,6 +5770,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
// TODO: add support for the new F32 operations
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:

View file

@ -27,5 +27,6 @@
#include "tsembd.hpp"
#include "im2col.hpp"
#include "wkv6.hpp"
#include "outprod.hpp"
#endif // GGML_SYCL_BACKEND_HPP

View file

@ -0,0 +1,66 @@
// Copyright (C) 2024 Zhiyuan Li
#include <sycl/sycl.hpp>
#include "outprod.hpp"
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
// Get strides
const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
// Get SYCL queue
dpct::queue_ptr stream = ctx.stream();
// Dimension checks
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
// Get data pointers
const float* src0_d = (const float*)src0->data;
const float* src1_d = (const float*)src1->data;
float* dst_d = (float*)dst->data;
// GEMM parameters
const float alpha = 1.0f;
const float beta = 0.0f;
// Handle transposition of src1
const bool src1_T = ggml_is_transposed(src1);
const oneapi::mkl::transpose src1_op =
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
try {
// Perform matrix multiplication using oneMKL GEMM
oneapi::mkl::blas::gemm(*stream,
oneapi::mkl::transpose::nontrans, src1_op,
ne0, ne1, ne01,
alpha,
src0_d, ne00,
src1_d, ldb,
beta,
dst_d, ne0);
}
catch (sycl::exception const& exc) {
std::cerr << exc.what() << std::endl;
GGML_ASSERT(false);
}
}

View file

@ -0,0 +1,11 @@
#ifndef GGML_SYCL_OUTPROD_HPP
#define GGML_SYCL_OUTPROD_HPP
#include "common.hpp"
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst);
#endif // GGML_SYCL_OUTPROD_HPP

View file

@ -25,6 +25,11 @@
#define SYCL_RELU_BLOCK_SIZE 256
#define SYCL_HARDSIGMOID_BLOCK_SIZE 256
#define SYCL_HARDSWISH_BLOCK_SIZE 256
#define SYCL_EXP_BLOCK_SIZE 256
#define SYCL_NEG_BLOCK_SIZE 256
#define SYCL_SIGMOID_BLOCK_SIZE 256
#define SYCL_SQRT_BLOCK_SIZE 256
#define SYCL_SIN_BLOCK_SIZE 256
#define SYCL_SQR_BLOCK_SIZE 256
#define SYCL_CPY_BLOCK_SIZE 32
#define SYCL_SCALE_BLOCK_SIZE 256
@ -41,6 +46,7 @@
#define SYCL_ACC_BLOCK_SIZE 256
#define SYCL_IM2COL_BLOCK_SIZE 256
#define SYCL_POOL2D_BLOCK_SIZE 256
#define SYCL_ARGMAX_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256