Add back ggml_sycl_set_device to kernels

This commit is contained in:
Akarshan Biswas 2025-02-03 11:53:22 +05:30
parent 0ae9a07cf8
commit 7369e54b33
No known key found for this signature in database
GPG key ID: 52A578A14B32134D
20 changed files with 48 additions and 2 deletions

View file

@ -58,6 +58,7 @@ static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * d
const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);

View file

@ -111,6 +111,7 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);

View file

@ -237,6 +237,7 @@ inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
void * dst_dd = static_cast<void *>(dst->data);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
main_stream);
@ -250,6 +251,7 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
void * dst_dd = static_cast<void *>(dst->data);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
main_stream);
@ -263,6 +265,7 @@ inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
void * dst_dd = static_cast<void *>(dst->data);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
main_stream);
@ -276,6 +279,7 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
void * dst_dd = static_cast<void *>(dst->data);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
main_stream);
@ -288,6 +292,7 @@ inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * d
const void * src0_d = static_cast<void *>(dst->src[0]->data);
void * dst_d = static_cast<void *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
} catch (const sycl::exception & exc) {

View file

@ -30,6 +30,7 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
memcpy(&min, dst->op_params, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);

View file

@ -162,6 +162,7 @@ static void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor * d
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int32_t dim = ((int32_t *)dst->op_params)[0];

View file

@ -79,6 +79,7 @@ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor
float * dst_d = (float *)dst->data;
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

View file

@ -37,6 +37,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
const int n_past = ((int32_t *) dst->op_params)[0];
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);

View file

@ -514,6 +514,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -526,6 +527,7 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -538,6 +540,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -551,6 +554,7 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@ -562,6 +566,7 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@ -573,6 +578,7 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@ -585,6 +591,7 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -597,6 +604,7 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@ -608,6 +616,7 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -620,6 +629,7 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *d
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -632,6 +642,7 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@ -643,6 +654,7 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -655,6 +667,7 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -669,6 +682,7 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
@ -681,6 +695,7 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
@ -697,6 +712,7 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
}
@ -709,6 +725,7 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
@ -727,6 +744,7 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
@ -743,6 +761,7 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
pad_f32_sycl(src0_dd, dst_dd,
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
@ -760,6 +779,7 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
const dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);

View file

@ -84,6 +84,7 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
@ -113,6 +114,7 @@ template <typename src0_t> static void get_rows_sycl_float(ggml_backend_sycl_con
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
{
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

View file

@ -3081,7 +3081,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
}
int ggml_backend_sycl_get_device_count() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
// GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count:\n");
return ggml_sycl_info().device_count;
}

View file

@ -88,6 +88,7 @@ void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor
const int64_t H = dst->src[0]->ne[1];
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64 || C / H == 128);

View file

@ -112,6 +112,7 @@ static void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * d
const int64_t batch = dst->src[1]->ne[3];
const size_t batch_offset = dst->src[1]->nb[3] / 4; // nb is byte offset, src is type float32
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
if (dst->type == GGML_TYPE_F16) {
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);

View file

@ -326,6 +326,7 @@ static void ggml_sycl_op_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
} catch (const sycl::exception & exc) {
@ -348,6 +349,7 @@ static void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor*
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@ -368,6 +370,7 @@ static void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor *
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
} catch (const sycl::exception & exc) {

View file

@ -17,6 +17,8 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Get SYCL queue
dpct::queue_ptr stream = ctx.stream();
// set device
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
// Dimension checks
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match

View file

@ -93,6 +93,7 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * d
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sycl::range<3> block_nums(1, 1, num_blocks);

View file

@ -236,6 +236,7 @@ static void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
// compute
if (is_neox) {

View file

@ -29,6 +29,7 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
float * dst_dd = static_cast<float *>(dst->data);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
/*

View file

@ -244,7 +244,7 @@ static void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor *
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
ggml_sycl_set_device(ctx.device);
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
dpct::queue_ptr main_stream = ctx.stream();
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {

View file

@ -31,6 +31,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
const int64_t ne = ggml_nelements(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
@ -48,6 +49,7 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);

View file

@ -115,6 +115,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
// Calculate execution configuration
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td