diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e7d260bd4..03e06c3a6 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -7984,26 +7984,24 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ const int blocks_per_row = ncols / qk; const int blocks_per_warp = vdr * WARP_SIZE / qi; - const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block - // partial sum for each thread float tmp = 0.0f; const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row; + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row * blocks_per_row + i; // x block index + const int ibx = row*blocks_per_row + i; // x block index - const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - const int iqs = - vdr * - (item_ct1.get_local_id(2) - - i * qi_vdr); // x block quant index when casting the quants to int + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int - tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); } // sum up partial sums and write back result @@ -8826,7 +8824,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { } struct rope_corr_dims { - float v[4]; + float v[2]; }; // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn @@ -8850,29 +8848,38 @@ static void rope_yarn( } // rope == RoPE == rotary positional embedding -template -static void rope( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - float ext_factor, float attn_factor, rope_corr_dims corr_dims -, +template +static void rope_norm( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, const sycl::nd_item<3> &item_ct1) { - const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - const int i = row*ncols + col; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + 1]; @@ -8881,25 +8888,25 @@ static void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template -static void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, - const float * freq_factors, const sycl::nd_item<3> &item_ct1) { - const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + +template +static void rope_neox(const T *x, T *dst, int ne0, int n_dims, + const int32_t *pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, + rope_corr_dims corr_dims, float theta_scale, + const float *freq_factors, + const sycl::nd_item<3> &item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - const int ib = col / n_dims; - const int ic = col % n_dims; - if (ib > 0) { - const int i = row*ncols + ib*n_dims + ic; + if (i0 >= n_dims) { + const int i = row*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -8907,19 +8914,14 @@ static void rope_neox( return; } - const int i = row*ncols + ib*n_dims + ic/2; + const int i = row*ne0 + i0/2; const int i2 = row/p_delta_rows; - float cur_rot = inv_ndims * ic - ib; - - const int p = has_pos ? pos[i2] : 0; - const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; - - const float theta_base = - p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor; + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + n_dims/2]; @@ -12375,15 +12377,18 @@ static void clamp_f32_sycl(const float *x, float *dst, const float min, } template -static void rope_sycl(const T *x, T *dst, int ncols, int nrows, +static void rope_norm_sycl(const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, - rope_corr_dims corr_dims, dpct::queue_ptr stream) { - GGML_ASSERT(ncols % 2 == 0); + rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); - const sycl::range<3> block_nums(1, num_blocks_x, nrows); - if (pos == nullptr) { + const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const sycl::range<3> block_nums(1, n_blocks_x, nr); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { /* DPCT1049:40: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query @@ -12395,8 +12400,8 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope(x, dst, ncols, pos, freq_scale, p_delta_rows, - freq_base, ext_factor, attn_factor, corr_dims, + rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, + ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); }); } else { @@ -12411,70 +12416,46 @@ static void rope_sycl(const T *x, T *dst, int ncols, int nrows, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope(x, dst, ncols, pos, freq_scale, p_delta_rows, - freq_base, ext_factor, attn_factor, corr_dims, + rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, + ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); }); } } template -static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, +static void rope_neox_sycl(const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, dpct::queue_ptr stream) { - GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); - const sycl::range<3> block_nums(1, num_blocks_x, nrows); + const int n_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const sycl::range<3> block_nums(1, n_blocks_x, nr); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.0f / n_dims; - if (pos == nullptr) { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - if (freq_factors == nullptr) { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, freq_factors, - item_ct1); - }); - } else { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, freq_factors, - item_ct1); - }); - } + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, + item_ct1); + }); } else { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - if (freq_factors == nullptr) { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); - }); - } else { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ncols, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1); - }); - } + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, + item_ct1); + }); } } @@ -13089,12 +13070,9 @@ void *ggml_sycl_host_malloc(size_t size) try { return nullptr; } - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - void * ptr = nullptr; dpct::err0 err = CHECK_TRY_ERROR( - ptr = (void *)sycl::malloc_host(size, *main_stream)); + ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue())); if (err != 0) { // clear the error @@ -13115,9 +13093,7 @@ catch (sycl::exception const &exc) { } void ggml_sycl_host_free(void *ptr) try { - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream))); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue()))); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -14005,8 +13981,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nr = ggml_nrows(src0); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -14023,27 +13998,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const float * freq_factors = nullptr; - const int32_t * pos = nullptr; - if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_dd; - } - const bool is_neox = mode & 2; -#pragma message("TODO: update rope NORM mode to match NEOX mode") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") + const int32_t * pos = (const int32_t *) src1_dd; - if (is_neox) { - pos = (const int32_t *) src1_dd; - - if (src2 != nullptr) { - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; } rope_corr_dims corr_dims; @@ -14053,12 +14014,12 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, if (is_neox) { if (src0->type == GGML_TYPE_F32) { rope_neox_sycl( - (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, - ne00, n_dims, nrows, pos, freq_scale, ne01, + ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); } else { @@ -14066,14 +14027,14 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, } } else { if (src0->type == GGML_TYPE_F32) { - rope_sycl( - (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, main_stream + rope_norm_sycl( + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream ); } else if (src0->type == GGML_TYPE_F16) { - rope_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, - nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, main_stream); + rope_norm_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, + n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream); } else { GGML_ASSERT(false); } @@ -17267,7 +17228,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + int dim = op->op_params[0]; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2; + } break; + case GGML_OP_ROPE: + { + return ggml_is_contiguous(op->src[0]); } break; case GGML_OP_DUP: case GGML_OP_NONE: @@ -17287,7 +17253,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - case GGML_OP_ROPE: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: