Compare commits

...
Sign in to create a new pull request.

6 commits

Author SHA1 Message Date
Joe Todd
ce6e28cc23
Update ggml-sycl.cpp 2024-06-18 09:57:14 +01:00
Joe Todd
6f6612570e Revert "Minor arithmetic improvement to mmvq wrapper kernel (#7172)"
This reverts commit 8c570c9496.
2024-06-14 22:22:57 +01:00
Joe Todd
18133cab40 Revert "use the correct SYCL context for host USM allocations"
Manually reverting:
https://github.com/ggerganov/llama.cpp/pull/7858

Signed-off-by: Joe Todd <joe.todd@codeplay.com>
2024-06-13 12:08:27 +01:00
Joe Todd
abd7c7b8c2 Formatting
Signed-off-by: Joe Todd <joe.todd@codeplay.com>
2024-06-13 10:36:05 +01:00
Joe Todd
0c0f3f0000 [SYCL] Update unsupported ops
Rope is only supported for contiguous input data.

Concat currently only supports dim=2

Signed-off-by: Joe Todd <joe.todd@codeplay.com>
2024-06-13 10:33:34 +01:00
Joe Todd
9b81b57239 [SYCL] unify rope norm/neox
As per: https://github.com/ggerganov/llama.cpp/pull/7634

Signed-off-by: Joe Todd <joe.todd@codeplay.com>
2024-06-13 10:30:43 +01:00

View file

@ -7984,15 +7984,13 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
@ -8000,8 +7998,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
const int iqs =
vdr *
(item_ct1.get_local_id(2) -
i * qi_vdr); // x block quant index when casting the quants to int
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
}
@ -8826,7 +8824,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
}
struct rope_corr_dims {
float v[4];
float v[2];
};
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
@ -8850,29 +8848,38 @@ static void rope_yarn(
}
// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos>
static void rope(
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
,
template<typename T, bool has_ff>
static void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (col >= ncols) {
if (i0 >= ne0) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int i = row*ncols + col;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + 1];
@ -8881,25 +8888,25 @@ static void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_pos, bool has_freq_facs>
static void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
template <typename T, bool has_ff>
static void rope_neox(const T *x, T *dst, int ne0, int n_dims,
const int32_t *pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor,
rope_corr_dims corr_dims, float theta_scale,
const float *freq_factors,
const sycl::nd_item<3> &item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (col >= ncols) {
if (i0 >= ne0) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int ib = col / n_dims;
const int ic = col % n_dims;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
@ -8907,19 +8914,14 @@ static void rope_neox(
return;
}
const int i = row*ncols + ib*n_dims + ic/2;
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
const float theta_base =
p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
@ -12375,15 +12377,18 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min,
}
template <typename T>
static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
static void rope_norm_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
const int32_t *pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor,
rope_corr_dims corr_dims, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % 2 == 0);
rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
if (pos == nullptr) {
const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, n_blocks_x, nr);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
@ -12395,8 +12400,8 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope<T, false>(x, dst, ncols, pos, freq_scale, p_delta_rows,
freq_base, ext_factor, attn_factor, corr_dims,
rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
} else {
@ -12411,71 +12416,47 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope<T, true>(x, dst, ncols, pos, freq_scale, p_delta_rows,
freq_base, ext_factor, attn_factor, corr_dims,
rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
item_ct1);
});
}
}
template <typename T>
static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
static void rope_neox_sycl(const T *x, T *dst, int ne0, int n_dims, int nr,
const int32_t *pos, float freq_scale,
int p_delta_rows, float freq_base, float ext_factor,
float attn_factor, rope_corr_dims corr_dims,
const float * freq_factors, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, num_blocks_x, nrows);
const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, n_blocks_x, nr);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
} else {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors,
corr_dims, theta_scale, freq_factors,
item_ct1);
});
}
} else {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
if (freq_factors == nullptr) {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
});
} else {
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
p_delta_rows, ext_factor, attn_factor,
corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
});
}
}
}
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
@ -13089,12 +13070,9 @@ void *ggml_sycl_host_malloc(size_t size) try {
return nullptr;
}
ggml_sycl_set_device(g_main_device);
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
void * ptr = nullptr;
dpct::err0 err = CHECK_TRY_ERROR(
ptr = (void *)sycl::malloc_host(size, *main_stream));
ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));
if (err != 0) {
// clear the error
@ -13115,9 +13093,7 @@ catch (sycl::exception const &exc) {
}
void ggml_sycl_host_free(void *ptr) try {
ggml_sycl_set_device(g_main_device);
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@ -14005,8 +13981,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
const int64_t nr = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@ -14023,28 +13998,14 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_dd;
}
const bool is_neox = mode & 2;
#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
if (is_neox) {
pos = (const int32_t *) src1_dd;
const int32_t * pos = (const int32_t *) src1_dd;
const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
@ -14053,12 +14014,12 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
ne00, n_dims, nrows, pos, freq_scale, ne01,
ne00, n_dims, nr, pos, freq_scale, ne01,
freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, main_stream);
} else {
@ -14066,14 +14027,14 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
rope_norm_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream);
rope_norm_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00,
n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream);
} else {
GGML_ASSERT(false);
}
@ -17267,7 +17228,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
int dim = op->op_params[0];
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
} break;
case GGML_OP_ROPE:
{
return ggml_is_contiguous(op->src[0]);
} break;
case GGML_OP_DUP:
case GGML_OP_NONE:
@ -17287,7 +17253,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS: