Compare commits

11 commits

Author      SHA1        Message                                                               Date

Aidan       a235b7c532  Vectorize q load                                                      2024-06-17 10:30:40 +01:00

Aidan       604ef6bf15  Store scales in local mem                                             2024-06-17 10:26:18 +01:00

Aidan       cb3fb42046  Single load for half2                                                 2024-06-17 10:21:16 +01:00

Aidan       4a481556e6  Remove double lines                                                   2024-06-17 10:16:10 +01:00

Joe Todd    ff076b8873  Merge pull request #7920 from ggerganov/codeplay/revert-host-alloc    2024-06-14 15:10:33 +01:00
                        Revert "use the correct SYCL context for host USM allocations"

Joe Todd    b2c8c831c9  Merge pull request #7919 from ggerganov/codeplay/unify-rope-sycl      2024-06-14 15:08:23 +01:00
                        sycl-exp : unify rope neox/norm

Joe Todd    ded54b5d9b  Replace powf with sycl::pow in ggml-sycl.cpp                          2024-06-14 13:14:33 +01:00
                        Signed-off-by: Joe Todd <joe.todd@codeplay.com>

Joe Todd    18133cab40  Revert "use the correct SYCL context for host USM allocations"       2024-06-13 12:08:27 +01:00
                        Manually reverting:
                        https://github.com/ggerganov/llama.cpp/pull/7858
                        Signed-off-by: Joe Todd <joe.todd@codeplay.com>

Joe Todd    abd7c7b8c2  Formatting                                                            2024-06-13 10:36:05 +01:00
                        Signed-off-by: Joe Todd <joe.todd@codeplay.com>

Joe Todd    0c0f3f0000  [SYCL] Update unsupported ops                                         2024-06-13 10:33:34 +01:00
                        Rope is only supported for contiguous input data.
                        Concat currently only supports dim=2
                        Signed-off-by: Joe Todd <joe.todd@codeplay.com>

Joe Todd    9b81b57239  [SYCL] unify rope norm/neox                                           2024-06-13 10:30:43 +01:00
                        As per: https://github.com/ggerganov/llama.cpp/pull/7634
                        Signed-off-by: Joe Todd <joe.todd@codeplay.com>

ggml-sycl.cpp

@@ -4297,7 +4297,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
-        d = q[j] & 63; m = q[j + 4] & 63;
+        d = q[j] & 63;
+        m = q[j + 4] & 63;
     } else {
         d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
         m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
@@ -4306,7 +4307,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
+                                  uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
     const block_q4_K * x = (const block_q4_K *) vx;

     const int i = item_ct1.get_group(2);
@@ -4320,19 +4321,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
     dst_t * y = yy + i*QK_K + 64*il + n*ir;

-    const float dall = x[i].dm[0];
-    const float dmin = x[i].dm[1];
+    const sycl::half2 dm = x[i].dm;
+    const float dall = dm[0];
+    const float dmin = dm[1];

-    const uint8_t * q = x[i].qs + 32*il + n*ir;
+    if (tid < 12)
+        scales_local[tid] = x[i].scales[tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);

     uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[i].scales, sc, m);
-    const float d1 = dall * sc; const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[i].scales, sc, m);
-    const float d2 = dall * sc; const float m2 = dmin * m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = reinterpret_cast<const sycl::vec<uint8_t, n>*>(x[i].qs + 32*il + n*ir)[0];
     for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q[l] & 0xF) - m1;
-        y[l +32] = d2 * (q[l] >> 4) - m2;
+        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
+        y[l +32] = d2 * (q_vec[l] >> 4) - m2;
     }
 }
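Note: the kernel change above combines three of the listed commits. "Store scales in local mem" stages the 12 scale bytes once per work-group behind a local barrier, "Single load for half2" reads dm in one half2 access, and "Vectorize q load" replaces per-byte pointer reads with a single sycl::vec load. Below is a standalone sketch of the local-memory and vectorized-load patterns; every name and size in it (block_size, n_scales, n_vec) is an assumption for illustration, not the llama.cpp kernel.

// Standalone sketch: stage a small table into work-group local memory behind a
// barrier, then read contiguous bytes with one vectorized load per work-item.
// All sizes below are illustrative assumptions, not the llama.cpp values.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <vector>

int main() {
    constexpr int block_size = 32;  // work-items per work-group (assumed)
    constexpr int n_scales   = 12;  // bytes of per-block metadata to stage (assumed)
    constexpr int n_vec      = 4;   // bytes fetched by each vectorized load (assumed)

    std::vector<uint8_t> src(block_size * n_vec, 1);
    std::vector<uint8_t> scales(n_scales, 3);
    std::vector<uint8_t> dst(block_size * n_vec, 0);

    sycl::queue q;
    {
        sycl::buffer<uint8_t, 1> src_buf(src.data(), sycl::range<1>(src.size()));
        sycl::buffer<uint8_t, 1> scales_buf(scales.data(), sycl::range<1>(scales.size()));
        sycl::buffer<uint8_t, 1> dst_buf(dst.data(), sycl::range<1>(dst.size()));

        q.submit([&](sycl::handler &cgh) {
            sycl::accessor s{src_buf, cgh, sycl::read_only};
            sycl::accessor sc{scales_buf, cgh, sycl::read_only};
            sycl::accessor d{dst_buf, cgh, sycl::write_only};
            // work-group local scratch, analogous to scales_local above
            sycl::local_accessor<uint8_t, 1> scales_local(sycl::range<1>(n_scales), cgh);

            cgh.parallel_for(
                sycl::nd_range<1>(sycl::range<1>(block_size), sycl::range<1>(block_size)),
                [=](sycl::nd_item<1> it) {
                    const int tid = it.get_local_id(0);
                    // only the first n_scales work-items copy; everyone waits at the barrier
                    if (tid < n_scales) {
                        scales_local[tid] = sc[tid];
                    }
                    it.barrier(sycl::access::fence_space::local_space);

                    // one vectorized load of n_vec contiguous bytes per work-item
                    const sycl::vec<uint8_t, n_vec> v =
                        *reinterpret_cast<const sycl::vec<uint8_t, n_vec> *>(&s[tid * n_vec]);
                    for (int l = 0; l < n_vec; ++l) {
                        d[tid * n_vec + l] = v[l] * scales_local[tid % n_scales];
                    }
                });
        });
    }
    return 0;
}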
@@ -8826,7 +8834,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 }

 struct rope_corr_dims {
-    float v[4];
+    float v[2];
 };

 // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
@@ -8850,29 +8858,38 @@ static void rope_yarn(
 }

 // rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-    ,
-    const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template<typename T, bool has_ff>
+static void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+    const sycl::nd_item<3> &item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));

-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }

     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                     item_ct1.get_local_id(2);
-    const int i = row*ncols + col;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ne0 + i0;
     const int i2 = row/p_delta_rows;

-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
+    const float theta_base = pos[i2]*sycl::pow(theta_scale, i0/2.0f);
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
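The theta_base rewrite is an algebraic hoist: with theta_scale = freq_base^(-2/n_dims) computed once on the host, pos * theta_scale^(i0/2) equals the per-element pos * freq_base^(-i0/n_dims). The old kernel divided by ncols, which coincides with n_dims when the whole head is rotated; the unified version always uses n_dims, matching the NEOX path. A tiny host-side check of the identity, with assumed values:

// Rough numeric check of the theta_scale identity (all values here are assumed,
// purely to illustrate the algebra; they are not taken from the diff).
#include <cmath>
#include <cstdio>

int main() {
    const float freq_base = 10000.0f;  // typical RoPE base (assumed)
    const int   n_dims    = 128;       // rotated dimensions (assumed)
    const int   pos       = 7;         // token position (assumed)
    const int   i0        = 10;        // even element index (assumed)

    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
    const float hoisted     = pos * std::pow(theta_scale, i0 / 2.0f);
    const float direct      = pos * std::pow(freq_base, -(float)i0 / n_dims);
    std::printf("%.6f vs %.6f\n", hoisted, direct);  // agree up to float rounding
    return 0;
}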
@@ -8881,25 +8898,25 @@ static void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-template<typename T, bool has_pos, bool has_freq_facs>
-static void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
-    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template <typename T, bool has_ff>
+static void rope_neox(const T *x, T *dst, int ne0, int n_dims,
+                      const int32_t *pos, float freq_scale, int p_delta_rows,
+                      float ext_factor, float attn_factor,
+                      rope_corr_dims corr_dims, float theta_scale,
+                      const float *freq_factors,
+                      const sycl::nd_item<3> &item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));

-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }

     const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                     item_ct1.get_local_id(2);
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;

-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;

         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -8907,19 +8924,14 @@ static void rope_neox(
         return;
     }

-    const int i = row*ncols + ib*n_dims + ic/2;
+    const int i = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;

-    float cur_rot = inv_ndims * ic - ib;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-
-    const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
+    const float theta_base = pos[i2]*sycl::pow(theta_scale, i0/2.0f);
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -9886,12 +9898,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q4_K(vx, y, item_ct1);
+                                 dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
                              });
+        });
     }
 }
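Passing work-group local memory into the kernel requires the launch to go through queue::submit() so a sycl::handler is available to create the local_accessor. Below is a minimal sketch of just that host-side plumbing; the kernel body is a placeholder, not the real dequantizer.

// Sketch of the submit-with-local-scratch launch pattern used above.
// Only the local_accessor plumbing matters; the kernel body is filler.
#include <sycl/sycl.hpp>
#include <cstdint>

void launch_with_local_scratch(sycl::queue &stream, int nb) {
    stream.submit([&](sycl::handler &cgh) {
        // 12 bytes of work-group local memory, mirroring scale_local_acc above
        sycl::local_accessor<uint8_t, 1> scratch(sycl::range<1>(12), cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32),
                              sycl::range<3>(1, 1, 32)),
            [=](sycl::nd_item<3> item_ct1) {
                // the real code hands scratch.get_pointer() to dequantize_block_q4_K;
                // here the scratch is only touched to keep the sketch self-contained
                const size_t tid = item_ct1.get_local_id(2);
                if (tid < 12) {
                    scratch[tid] = static_cast<uint8_t>(tid);
                }
                item_ct1.barrier(sycl::access::fence_space::local_space);
            });
    });
    stream.wait();
}

The nd_range geometry mirrors the launch in the diff: nb work-groups of 32 work-items in the innermost dimension.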
@@ -12375,15 +12390,18 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
 }

 template <typename T>
-static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
-                      const int32_t *pos, float freq_scale, int p_delta_rows,
-                      float freq_base, float ext_factor, float attn_factor,
-                      rope_corr_dims corr_dims, dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+static void rope_norm_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
+                           const int32_t *pos, float freq_scale, int p_delta_rows,
+                           float freq_base, float ext_factor, float attn_factor,
+                           rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
+    const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, n_blocks_x, nr);

-    if (pos == nullptr) {
+    const float theta_scale = sycl::pow(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
         /*
         DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
@@ -12395,8 +12413,8 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) {
-                rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                               freq_base, ext_factor, attn_factor, corr_dims,
+                rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                                    ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
                                item_ct1);
             });
     } else {
@@ -12411,70 +12429,46 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) {
-                rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
-                              freq_base, ext_factor, attn_factor, corr_dims,
+                rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+                                   ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
                                item_ct1);
             });
     }
 }

 template <typename T>
-static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
+static void rope_neox_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
                            const int32_t *pos, float freq_scale,
                            int p_delta_rows, float freq_base, float ext_factor,
                            float attn_factor, rope_corr_dims corr_dims,
                            const float * freq_factors, dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, num_blocks_x, nrows);
+    const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, n_blocks_x, nr);

-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
+    const float theta_scale = sycl::pow(freq_base, -2.0f/n_dims);

-    if (pos == nullptr) {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                               p_delta_rows, ext_factor, attn_factor,
-                                               corr_dims, theta_scale, inv_ndims, freq_factors,
-                                               item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors,
-                                              item_ct1);
-                });
-        }
+    dpct::has_capability_or_fail(stream->get_device(),
+                                 {sycl::aspect::fp16});
+
+    if (freq_factors == nullptr) {
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
+                                    p_delta_rows, ext_factor, attn_factor,
+                                    corr_dims, theta_scale, freq_factors,
+                                    item_ct1);
+            });
     } else {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        if (freq_factors == nullptr) {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        } else {
-            stream->parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                              p_delta_rows, ext_factor, attn_factor,
-                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
-                });
-        }
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
+                                   p_delta_rows, ext_factor, attn_factor,
+                                   corr_dims, theta_scale, freq_factors,
+                                   item_ct1);
+            });
     }
 }
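With pos now read unconditionally, the has_pos template parameter disappears and rope_neox_sycl drops from four kernel instantiations to two. An illustrative-only sketch of the remaining dispatch shape (not the real kernel or launcher):

// One runtime branch on freq_factors picks one of two template instantiations,
// replacing the previous four-way has_pos x has_freq_facs matrix.
#include <cstdio>

template <bool has_ff>
void rope_like(const float *freq_factors, int i0) {
    // resolved at compile time within each instantiation
    const float ff = has_ff ? freq_factors[i0 / 2] : 1.0f;
    std::printf("freq factor: %f\n", ff);
}

void dispatch(const float *freq_factors, int i0) {
    if (freq_factors == nullptr) {
        rope_like<false>(nullptr, i0);
    } else {
        rope_like<true>(freq_factors, i0);
    }
}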
@@ -12592,8 +12586,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     const uint32_t n_head_kv   = nrows_x/nrows_y;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    const float m0 = sycl::pow(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = sycl::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
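The only change here is powf -> sycl::pow for the ALiBi slope bases m0 and m1, both computed on the host. For a feel of the numbers (illustrative values, not from the diff): with max_bias = 8.0 and n_head_kv = 16, n_head_log2 is 16, so m0 = 2^(-0.5) ~ 0.7071 and m1 = 2^(-0.25) ~ 0.8409; sycl::pow on the host returns the same values powf did.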
@@ -13089,12 +13083,9 @@ void *ggml_sycl_host_malloc(size_t size) try {
         return nullptr;
     }

-    ggml_sycl_set_device(g_main_device);
-    dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
-
     void * ptr = nullptr;
     dpct::err0 err = CHECK_TRY_ERROR(
-        ptr = (void *)sycl::malloc_host(size, *main_stream));
+        ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));

     if (err != 0) {
         // clear the error
@@ -13115,9 +13106,7 @@ catch (sycl::exception const &exc) {
 }

 void ggml_sycl_host_free(void *ptr) try {
-    ggml_sycl_set_device(g_main_device);
-    dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
-    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
+    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
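Both helpers now allocate and free host USM against the queue returned by dpct::get_in_order_queue() rather than the main device stream, which is what the revert of #7858 (commit 18133cab40) puts back. A minimal standalone sketch of the same calls, using a plain sycl::queue as a stand-in for the dpct helper:

// Sketch of the USM host-allocation pattern; sizes are arbitrary for the example.
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    sycl::queue q;                           // any queue tied to the desired context
    const size_t size = 1 << 20;             // 1 MiB, just for the example
    void *ptr = sycl::malloc_host(size, q);  // pinned host USM allocation
    if (ptr == nullptr) {
        std::fprintf(stderr, "sycl::malloc_host failed\n");
        return 1;
    }
    // ... use ptr as a staging buffer for host<->device copies ...
    sycl::free(ptr, q);                      // free against the same context it came from
    return 0;
}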
@@ -14005,8 +13994,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);

     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -14023,27 +14011,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     memcpy(&beta_fast, (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

-    const float * freq_factors = nullptr;
-    const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_dd;
-    }
-
     const bool is_neox = mode & 2;

-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
-    if (is_neox) {
-        pos = (const int32_t *) src1_dd;
+    const int32_t * pos = (const int32_t *) src1_dd;

-        if (src2 != nullptr) {
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
     }

     rope_corr_dims corr_dims;
@@ -14053,12 +14027,12 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
-                           ne00, n_dims, nrows, pos, freq_scale, ne01,
+                           ne00, n_dims, nr, pos, freq_scale, ne01,
                            freq_base, ext_factor, attn_factor, corr_dims,
                            freq_factors, main_stream);
         } else {
@@ -14066,14 +14040,14 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_sycl(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
+            rope_norm_sycl(
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
-                      nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                      attn_factor, corr_dims, main_stream);
+            rope_norm_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
+                           n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                           attn_factor, corr_dims, freq_factors, main_stream);
         } else {
             GGML_ASSERT(false);
         }
@@ -17267,7 +17241,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+                int dim = op->op_params[0];
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
+            } break;
+        case GGML_OP_ROPE:
+            {
+                return ggml_is_contiguous(op->src[0]);
             } break;
         case GGML_OP_DUP:
         case GGML_OP_NONE:
@@ -17287,7 +17266,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
-        case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
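One practical consequence of the tightened supports_op checks (a hedged reading, not stated in the diff): graphs that feed non-contiguous tensors into ROPE, or CONCAT along a dimension other than 2, should now fall back to the CPU instead of hitting the SYCL kernels. A caller that wants to keep ROPE on the SYCL backend could force contiguity first; the snippet below is hypothetical, although ggml_is_contiguous() and ggml_cont() are existing ggml helpers.

// Hypothetical caller-side guard (not part of this change).
if (!ggml_is_contiguous(rope_src)) {
    rope_src = ggml_cont(ctx, rope_src);  // materialize a contiguous copy before rope
}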