Add back ggml_sycl_set_device to kernels
This commit is contained in:
parent
0ae9a07cf8
commit
7369e54b33
20 changed files with 48 additions and 2 deletions
|
@ -58,6 +58,7 @@ static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||||
|
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
||||||
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||||
|
|
|
@ -111,6 +111,7 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
||||||
|
|
||||||
|
|
|
@ -237,6 +237,7 @@ inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||||
void * dst_dd = static_cast<void *>(dst->data);
|
void * dst_dd = static_cast<void *>(dst->data);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
||||||
main_stream);
|
main_stream);
|
||||||
|
@ -250,6 +251,7 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||||
void * dst_dd = static_cast<void *>(dst->data);
|
void * dst_dd = static_cast<void *>(dst->data);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
||||||
main_stream);
|
main_stream);
|
||||||
|
@ -263,6 +265,7 @@ inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||||
void * dst_dd = static_cast<void *>(dst->data);
|
void * dst_dd = static_cast<void *>(dst->data);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
||||||
main_stream);
|
main_stream);
|
||||||
|
@ -276,6 +279,7 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
|
||||||
void * dst_dd = static_cast<void *>(dst->data);
|
void * dst_dd = static_cast<void *>(dst->data);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
|
||||||
main_stream);
|
main_stream);
|
||||||
|
@ -288,6 +292,7 @@ inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||||
const void * src0_d = static_cast<void *>(dst->src[0]->data);
|
const void * src0_d = static_cast<void *>(dst->src[0]->data);
|
||||||
void * dst_d = static_cast<void *>(dst->data);
|
void * dst_d = static_cast<void *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
|
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
|
||||||
} catch (const sycl::exception & exc) {
|
} catch (const sycl::exception & exc) {
|
||||||
|
|
|
@ -30,6 +30,7 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
memcpy(&min, dst->op_params, sizeof(float));
|
memcpy(&min, dst->op_params, sizeof(float));
|
||||||
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
|
|
@ -162,6 +162,7 @@ static void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||||
const ggml_tensor *src0 = dst->src[0];
|
const ggml_tensor *src0 = dst->src[0];
|
||||||
const ggml_tensor *src1 = dst->src[1];
|
const ggml_tensor *src1 = dst->src[1];
|
||||||
queue_ptr stream = ctx.stream();
|
queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
const int32_t dim = ((int32_t *)dst->op_params)[0];
|
const int32_t dim = ((int32_t *)dst->op_params)[0];
|
||||||
|
|
||||||
|
|
|
@ -79,6 +79,7 @@ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
|
|
||||||
float * dst_d = (float *)dst->data;
|
float * dst_d = (float *)dst->data;
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
|
@ -37,6 +37,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
|
|
@ -514,6 +514,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -526,6 +527,7 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -538,6 +540,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -551,6 +554,7 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
|
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
|
@ -562,6 +566,7 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
|
@ -573,6 +578,7 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
|
@ -585,6 +591,7 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
|
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -597,6 +604,7 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
|
@ -608,6 +616,7 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -620,6 +629,7 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *d
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -632,6 +642,7 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
|
@ -643,6 +654,7 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -655,6 +667,7 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -669,6 +682,7 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
}
|
}
|
||||||
|
@ -681,6 +695,7 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
}
|
}
|
||||||
|
@ -697,6 +712,7 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
|
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
|
||||||
}
|
}
|
||||||
|
@ -709,6 +725,7 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||||
}
|
}
|
||||||
|
@ -727,6 +744,7 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
|
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
|
||||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
|
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
|
||||||
|
@ -743,6 +761,7 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
pad_f32_sycl(src0_dd, dst_dd,
|
pad_f32_sycl(src0_dd, dst_dd,
|
||||||
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
|
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
|
||||||
|
@ -760,6 +779,7 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
|
||||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||||
|
|
||||||
const dpct::queue_ptr main_stream = ctx.stream();
|
const dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
|
@ -84,6 +84,7 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||||
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
|
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
|
||||||
|
@ -113,6 +114,7 @@ template <typename src0_t> static void get_rows_sycl_float(ggml_backend_sycl_con
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||||
|
|
|
@ -3081,7 +3081,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_backend_sycl_get_device_count() {
|
int ggml_backend_sycl_get_device_count() {
|
||||||
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
|
// GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count:\n");
|
||||||
return ggml_sycl_info().device_count;
|
return ggml_sycl_info().device_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,6 +88,7 @@ void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
const int64_t H = dst->src[0]->ne[1];
|
const int64_t H = dst->src[0]->ne[1];
|
||||||
|
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(C % H == 0);
|
GGML_ASSERT(C % H == 0);
|
||||||
GGML_ASSERT(C / H == 64 || C / H == 128);
|
GGML_ASSERT(C / H == 64 || C / H == 128);
|
||||||
|
|
|
@ -112,6 +112,7 @@ static void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||||
const int64_t batch = dst->src[1]->ne[3];
|
const int64_t batch = dst->src[1]->ne[3];
|
||||||
const size_t batch_offset = dst->src[1]->nb[3] / 4; // nb is byte offset, src is type float32
|
const size_t batch_offset = dst->src[1]->nb[3] / 4; // nb is byte offset, src is type float32
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
if (dst->type == GGML_TYPE_F16) {
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||||
|
|
|
@ -326,6 +326,7 @@ static void ggml_sycl_op_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||||
} catch (const sycl::exception & exc) {
|
} catch (const sycl::exception & exc) {
|
||||||
|
@ -348,6 +349,7 @@ static void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor*
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
|
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
|
||||||
} catch (const sycl::exception & exc) {
|
} catch (const sycl::exception & exc) {
|
||||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||||
|
@ -368,6 +370,7 @@ static void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||||
} catch (const sycl::exception & exc) {
|
} catch (const sycl::exception & exc) {
|
||||||
|
|
|
@ -17,6 +17,8 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||||
|
|
||||||
// Get SYCL queue
|
// Get SYCL queue
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
// set device
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
// Dimension checks
|
// Dimension checks
|
||||||
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
|
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
|
||||||
|
|
|
@ -93,6 +93,7 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||||
const int parallel_elements = N * OC * OH * OW;
|
const int parallel_elements = N * OC * OH * OW;
|
||||||
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
|
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
sycl::range<3> block_nums(1, 1, num_blocks);
|
sycl::range<3> block_nums(1, 1, num_blocks);
|
||||||
|
|
|
@ -236,6 +236,7 @@ static void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||||
rope_corr_dims corr_dims;
|
rope_corr_dims corr_dims;
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
|
|
@ -29,6 +29,7 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
|
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -244,7 +244,7 @@ static void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
ggml_sycl_set_device(ctx.device);
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
|
||||||
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
||||||
|
|
|
@ -31,6 +31,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
|
|
||||||
const int64_t ne = ggml_nelements(dst->src[0]);
|
const int64_t ne = ggml_nelements(dst->src[0]);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
@ -48,6 +49,7 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||||
const int64_t ncols = dst->src[0]->ne[0];
|
const int64_t ncols = dst->src[0]->ne[0];
|
||||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||||
dpct::queue_ptr main_stream = ctx.stream();
|
dpct::queue_ptr main_stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||||
float * dst_dd = static_cast<float *>(dst->data);
|
float * dst_dd = static_cast<float *>(dst->data);
|
||||||
|
|
||||||
|
|
|
@ -115,6 +115,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||||
|
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||||
|
|
||||||
// Calculate execution configuration
|
// Calculate execution configuration
|
||||||
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
|
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue