ggml : remove GLM rope mode
ggml-ci
This commit is contained in:
parent
e2370c8e41
commit
cbe4f5f7c0
13 changed files with 151 additions and 363 deletions
|
@ -522,8 +522,8 @@ static struct ggml_tensor * forward(
|
||||||
// wk shape [n_embd, n_embd, 1, 1]
|
// wk shape [n_embd, n_embd, 1, 1]
|
||||||
// Qcur shape [n_embd/n_head, n_head, N, 1]
|
// Qcur shape [n_embd/n_head, n_head, N, 1]
|
||||||
// Kcur shape [n_embd/n_head, n_head, N, 1]
|
// Kcur shape [n_embd/n_head, n_head, N, 1]
|
||||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
|
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);
|
||||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
|
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0);
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
{
|
{
|
||||||
|
@ -759,8 +759,8 @@ static struct ggml_tensor * forward_batch(
|
||||||
// wk shape [n_embd, n_embd, 1, 1]
|
// wk shape [n_embd, n_embd, 1, 1]
|
||||||
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
|
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||||
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
|
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
|
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
|
||||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
|
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0);
|
||||||
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
|
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
|
||||||
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
|
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
|
||||||
|
|
||||||
|
@ -1056,7 +1056,7 @@ static struct ggml_tensor * forward_lora(
|
||||||
model->layers[il].wqb,
|
model->layers[il].wqb,
|
||||||
cur)),
|
cur)),
|
||||||
n_embd/n_head, n_head, N),
|
n_embd/n_head, n_head, N),
|
||||||
KQ_pos, n_rot, 0, 0);
|
KQ_pos, n_rot, 0);
|
||||||
struct ggml_tensor * Kcur = ggml_rope(ctx0,
|
struct ggml_tensor * Kcur = ggml_rope(ctx0,
|
||||||
ggml_reshape_3d(ctx0,
|
ggml_reshape_3d(ctx0,
|
||||||
ggml_mul_mat(ctx0,
|
ggml_mul_mat(ctx0,
|
||||||
|
@ -1065,7 +1065,7 @@ static struct ggml_tensor * forward_lora(
|
||||||
model->layers[il].wkb,
|
model->layers[il].wkb,
|
||||||
cur)),
|
cur)),
|
||||||
n_embd/n_head, n_head, N),
|
n_embd/n_head, n_head, N),
|
||||||
KQ_pos, n_rot, 0, 0);
|
KQ_pos, n_rot, 0);
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
{
|
{
|
||||||
|
|
|
@ -564,7 +564,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||||
const int rope_mode = 0;
|
const int rope_mode = 0;
|
||||||
|
|
||||||
return ggml_rope_ext(ctx,
|
return ggml_rope_ext(ctx,
|
||||||
t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
|
t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx,
|
||||||
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
@ -302,7 +302,7 @@ static struct ggml_tensor * llama_build_train_graphs(
|
||||||
const int rope_mode = 0;
|
const int rope_mode = 0;
|
||||||
|
|
||||||
return ggml_rope_ext(
|
return ggml_rope_ext(
|
||||||
ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -100,46 +100,6 @@ static __global__ void rope_neox(
|
||||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void rope_glm_f32(
|
|
||||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
|
||||||
int n_ctx
|
|
||||||
) {
|
|
||||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
const int half_n_dims = ncols/4;
|
|
||||||
|
|
||||||
if (col >= half_n_dims) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
|
||||||
const int i = row*ncols + col;
|
|
||||||
const int i2 = row/p_delta_rows;
|
|
||||||
|
|
||||||
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
|
|
||||||
// FIXME: this is likely wrong
|
|
||||||
const int p = pos != nullptr ? pos[i2] : 0;
|
|
||||||
|
|
||||||
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
|
|
||||||
const float sin_theta = sinf(theta);
|
|
||||||
const float cos_theta = cosf(theta);
|
|
||||||
|
|
||||||
const float x0 = x[i + 0];
|
|
||||||
const float x1 = x[i + half_n_dims];
|
|
||||||
|
|
||||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
||||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
|
||||||
|
|
||||||
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
|
|
||||||
const float sin_block_theta = sinf(block_theta);
|
|
||||||
const float cos_block_theta = cosf(block_theta);
|
|
||||||
|
|
||||||
const float x2 = x[i + half_n_dims * 2];
|
|
||||||
const float x3 = x[i + half_n_dims * 3];
|
|
||||||
|
|
||||||
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
|
||||||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void rope_cuda(
|
static void rope_cuda(
|
||||||
|
@ -200,17 +160,6 @@ static void rope_neox_cuda(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_glm_f32_cuda(
|
|
||||||
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
|
||||||
float freq_base, int n_ctx, cudaStream_t stream
|
|
||||||
) {
|
|
||||||
GGML_ASSERT(ncols % 4 == 0);
|
|
||||||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
|
||||||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
|
||||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
|
||||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void rope_cuda_f16(
|
static void rope_cuda_f16(
|
||||||
const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
||||||
|
@ -263,7 +212,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
// RoPE alteration for extended context
|
// RoPE alteration for extended context
|
||||||
|
@ -279,7 +228,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const int32_t * pos = nullptr;
|
const int32_t * pos = nullptr;
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
pos = (const int32_t *) src1_d;
|
pos = (const int32_t *) src1_d;
|
||||||
|
|
||||||
|
@ -295,10 +243,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_glm) {
|
if (is_neox) {
|
||||||
GGML_ASSERT(false);
|
|
||||||
rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
|
|
||||||
} else if (is_neox) {
|
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
rope_neox_cuda_f32(
|
rope_neox_cuda_f32(
|
||||||
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||||
|
|
|
@ -2302,9 +2302,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
||||||
|
|
|
@ -8928,49 +8928,6 @@ static void rope_neox(
|
||||||
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_glm_f32(
|
|
||||||
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
|
|
||||||
int n_ctx
|
|
||||||
, const sycl::nd_item<3> &item_ct1) {
|
|
||||||
const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
||||||
item_ct1.get_local_id(2);
|
|
||||||
const int half_n_dims = ncols/4;
|
|
||||||
|
|
||||||
if (col >= half_n_dims) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
||||||
item_ct1.get_local_id(1);
|
|
||||||
const int i = row*ncols + col;
|
|
||||||
const int i2 = row/p_delta_rows;
|
|
||||||
|
|
||||||
const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
|
|
||||||
// FIXME: this is likely wrong
|
|
||||||
const int p = pos != nullptr ? pos[i2] : 0;
|
|
||||||
|
|
||||||
const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
|
|
||||||
const float sin_theta = sycl::sin((float)theta);
|
|
||||||
const float cos_theta = sycl::cos((float)theta);
|
|
||||||
|
|
||||||
const float x0 = x[i + 0];
|
|
||||||
const float x1 = x[i + half_n_dims];
|
|
||||||
|
|
||||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
||||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
|
||||||
|
|
||||||
const float block_theta =
|
|
||||||
((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
|
|
||||||
const float sin_block_theta = sycl::sin((float)block_theta);
|
|
||||||
const float cos_block_theta = sycl::cos((float)block_theta);
|
|
||||||
|
|
||||||
const float x2 = x[i + half_n_dims * 2];
|
|
||||||
const float x3 = x[i + half_n_dims * 3];
|
|
||||||
|
|
||||||
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
|
|
||||||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int row = item_ct1.get_group(1);
|
const int row = item_ct1.get_group(1);
|
||||||
|
@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
|
|
||||||
const int32_t *pos, float freq_scale,
|
|
||||||
int p_delta_rows, float freq_base, int n_ctx,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % 4 == 0);
|
|
||||||
const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
|
|
||||||
const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
|
|
||||||
const sycl::range<3> block_nums(1, nrows, num_blocks_x);
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
rope_glm_f32(x, dst, ncols, pos, freq_scale,
|
|
||||||
p_delta_rows, freq_base, n_ctx,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||||
const int nrows, dpct::queue_ptr stream) {
|
const int nrows, dpct::queue_ptr stream) {
|
||||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||||
|
@ -14066,7 +14007,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
// RoPE alteration for extended context
|
// RoPE alteration for extended context
|
||||||
|
@ -14087,7 +14028,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
pos = (const int32_t *) src1_dd;
|
pos = (const int32_t *) src1_dd;
|
||||||
|
@ -14103,10 +14043,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_glm) {
|
if (is_neox) {
|
||||||
GGML_ASSERT(false);
|
|
||||||
rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
|
|
||||||
} else if (is_neox) {
|
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
rope_neox_sycl(
|
rope_neox_sycl(
|
||||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||||
|
|
|
@ -3823,11 +3823,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
{
|
{
|
||||||
const int mode = ((const int32_t *) dst->op_params)[2];
|
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
if (is_glm) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
@ -4307,9 +4302,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||||
const float beta_slow = ((float *) dst->op_params)[10];
|
const float beta_slow = ((float *) dst->op_params)[10];
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
GGML_ASSERT(!is_glm);
|
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
@ -6365,9 +6357,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
const int mode = ((const int32_t *) op->op_params)[2];
|
const int mode = ((const int32_t *) op->op_params)[2];
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
return !is_glm;
|
return true;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
|
|
102
ggml.c
102
ggml.c
|
@ -6249,7 +6249,6 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6277,7 +6276,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_orig_ctx };
|
||||||
memcpy(params + 5, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
|
@ -6300,10 +6299,9 @@ struct ggml_tensor * ggml_rope(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode) {
|
||||||
int n_ctx) {
|
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
|
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6312,10 +6310,9 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode) {
|
||||||
int n_ctx) {
|
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
|
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6326,7 +6323,6 @@ struct ggml_tensor * ggml_rope_ext(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6335,7 +6331,7 @@ struct ggml_tensor * ggml_rope_ext(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, c, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6347,7 +6343,6 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6356,7 +6351,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, c, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6367,7 +6362,6 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6376,7 +6370,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, NULL, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6387,7 +6381,6 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6396,7 +6389,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, NULL, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6410,7 +6403,6 @@ struct ggml_tensor * ggml_rope_back(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -6433,7 +6425,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_orig_ctx };
|
||||||
memcpy(params + 5, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
|
@ -14306,7 +14298,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
|
@ -14347,7 +14339,6 @@ static void ggml_compute_forward_rope_f32(
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
const float * freq_factors = NULL;
|
const float * freq_factors = NULL;
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
@ -14372,42 +14363,12 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const int64_t p = pos[i2];
|
const int64_t p = pos[i2];
|
||||||
|
|
||||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||||
if (!is_glm) { // TODO: cache sin/cos for glm
|
|
||||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||||
if (ir++ < ir0) continue;
|
if (ir++ < ir0) continue;
|
||||||
if (ir > ir1) break;
|
if (ir > ir1) break;
|
||||||
|
|
||||||
if (is_glm) {
|
|
||||||
float theta_base = MIN(p, n_ctx - 2);
|
|
||||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
||||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
||||||
const float cos_theta = cosf(theta_base);
|
|
||||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
|
||||||
const float cos_block_theta = cosf(block_theta);
|
|
||||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
||||||
|
|
||||||
theta_base *= theta_scale;
|
|
||||||
block_theta *= theta_scale;
|
|
||||||
|
|
||||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
||||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
||||||
|
|
||||||
const float x0 = src[0];
|
|
||||||
const float x1 = src[n_dims/2];
|
|
||||||
const float x2 = src[n_dims];
|
|
||||||
const float x3 = src[n_dims/2*3];
|
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
||||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
||||||
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
|
|
||||||
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const float theta_base = (float)p;
|
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
const float cos_theta = cache[i0 + 0];
|
const float cos_theta = cache[i0 + 0];
|
||||||
|
@ -14450,7 +14411,6 @@ static void ggml_compute_forward_rope_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: deduplicate f16/f32 code
|
// TODO: deduplicate f16/f32 code
|
||||||
|
@ -14472,7 +14432,7 @@ static void ggml_compute_forward_rope_f16(
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
|
@ -14512,7 +14472,6 @@ static void ggml_compute_forward_rope_f16(
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
const float * freq_factors = NULL;
|
const float * freq_factors = NULL;
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
@ -14537,42 +14496,12 @@ static void ggml_compute_forward_rope_f16(
|
||||||
const int64_t p = pos[i2];
|
const int64_t p = pos[i2];
|
||||||
|
|
||||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||||
if (!is_glm) { // TODO: cache sin/cos for glm
|
|
||||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||||
if (ir++ < ir0) continue;
|
if (ir++ < ir0) continue;
|
||||||
if (ir > ir1) break;
|
if (ir > ir1) break;
|
||||||
|
|
||||||
if (is_glm) {
|
|
||||||
float theta_base = MIN(p, n_ctx - 2);
|
|
||||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
||||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
||||||
const float cos_theta = cosf(theta_base);
|
|
||||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
|
||||||
const float cos_block_theta = cosf(block_theta);
|
|
||||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
||||||
|
|
||||||
theta_base *= theta_scale;
|
|
||||||
block_theta *= theta_scale;
|
|
||||||
|
|
||||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
||||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
||||||
|
|
||||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
|
||||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
|
||||||
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
|
|
||||||
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
|
|
||||||
|
|
||||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
||||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
||||||
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
|
|
||||||
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const float theta_base = (float)p;
|
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
const float cos_theta = cache[i0 + 0];
|
const float cos_theta = cache[i0 + 0];
|
||||||
|
@ -14615,7 +14544,6 @@ static void ggml_compute_forward_rope_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_compute_forward_rope(
|
static void ggml_compute_forward_rope(
|
||||||
|
@ -18313,7 +18241,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
|
||||||
|
@ -18332,7 +18260,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
src2,
|
src2,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode,
|
mode,
|
||||||
n_ctx,
|
|
||||||
n_orig_ctx,
|
n_orig_ctx,
|
||||||
freq_base,
|
freq_base,
|
||||||
freq_scale,
|
freq_scale,
|
||||||
|
@ -18349,7 +18276,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
|
||||||
|
@ -18368,7 +18295,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
src2,
|
src2,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode,
|
mode,
|
||||||
n_ctx,
|
|
||||||
n_orig_ctx,
|
n_orig_ctx,
|
||||||
freq_base,
|
freq_base,
|
||||||
freq_scale,
|
freq_scale,
|
||||||
|
|
12
ggml.h
12
ggml.h
|
@ -1465,7 +1465,6 @@ extern "C" {
|
||||||
// rotary position embedding
|
// rotary position embedding
|
||||||
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
// if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
|
||||||
// if mode & 2 == 1, GPT-NeoX style
|
// if mode & 2 == 1, GPT-NeoX style
|
||||||
// if mode & 4 == 1, ChatGLM style
|
|
||||||
//
|
//
|
||||||
// b is an int32 vector with size a->ne[2], it contains the positions
|
// b is an int32 vector with size a->ne[2], it contains the positions
|
||||||
// c is freq factors (e.g. phi3-128k), (optional)
|
// c is freq factors (e.g. phi3-128k), (optional)
|
||||||
|
@ -1474,8 +1473,7 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode);
|
||||||
int n_ctx);
|
|
||||||
|
|
||||||
// in-place, returns view(a)
|
// in-place, returns view(a)
|
||||||
GGML_API struct ggml_tensor * ggml_rope_inplace(
|
GGML_API struct ggml_tensor * ggml_rope_inplace(
|
||||||
|
@ -1483,8 +1481,7 @@ extern "C" {
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode);
|
||||||
int n_ctx);
|
|
||||||
|
|
||||||
// custom RoPE
|
// custom RoPE
|
||||||
GGML_API struct ggml_tensor * ggml_rope_ext(
|
GGML_API struct ggml_tensor * ggml_rope_ext(
|
||||||
|
@ -1494,7 +1491,6 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -1511,7 +1507,6 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -1526,7 +1521,6 @@ extern "C" {
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -1542,7 +1536,6 @@ extern "C" {
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
@ -1565,7 +1558,6 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
|
||||||
int n_orig_ctx,
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
|
|
102
llama.cpp
102
llama.cpp
|
@ -7179,7 +7179,7 @@ struct llm_build_context {
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
0),
|
0),
|
||||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
|
||||||
cb(tmp, "K_shifted", il);
|
cb(tmp, "K_shifted", il);
|
||||||
|
@ -7404,14 +7404,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7535,12 +7535,12 @@ struct llm_build_context {
|
||||||
case MODEL_7B:
|
case MODEL_7B:
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
|
@ -7647,14 +7647,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7767,13 +7767,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7891,14 +7891,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8044,14 +8044,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8398,14 +8398,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8838,14 +8838,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr,
|
ctx0, Qcur, inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr,
|
ctx0, Kcur, inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8957,13 +8957,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9069,14 +9069,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9183,14 +9183,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9335,7 +9335,7 @@ struct llm_build_context {
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
@ -9346,7 +9346,7 @@ struct llm_build_context {
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9457,7 +9457,7 @@ struct llm_build_context {
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
@ -9466,7 +9466,7 @@ struct llm_build_context {
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
|
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9574,13 +9574,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
@ -9782,14 +9782,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
struct ggml_tensor * Qcur = ggml_rope_ext(
|
struct ggml_tensor * Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
struct ggml_tensor * Kcur = ggml_rope_ext(
|
struct ggml_tensor * Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9898,14 +9898,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10015,14 +10015,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10145,14 +10145,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10265,7 +10265,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head_k, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
@ -10274,7 +10274,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head_k, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
@ -10385,14 +10385,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10675,14 +10675,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10806,14 +10806,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10920,14 +10920,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -11055,14 +11055,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -11272,7 +11272,7 @@ struct llm_build_context {
|
||||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||||
q_pe = ggml_rope_ext(
|
q_pe = ggml_rope_ext(
|
||||||
ctx0, q_pe, inp_pos, nullptr,
|
ctx0, q_pe, inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(q_pe, "q_pe", il);
|
cb(q_pe, "q_pe", il);
|
||||||
|
@ -11281,7 +11281,7 @@ struct llm_build_context {
|
||||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||||
k_pe = ggml_rope_ext(
|
k_pe = ggml_rope_ext(
|
||||||
ctx0, k_pe, inp_pos, nullptr,
|
ctx0, k_pe, inp_pos, nullptr,
|
||||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(k_pe, "k_pe", il);
|
cb(k_pe, "k_pe", il);
|
||||||
|
|
|
@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
|
||||||
const std::array<int64_t, 4> ne_a;
|
const std::array<int64_t, 4> ne_a;
|
||||||
int n_dims;
|
int n_dims;
|
||||||
int mode;
|
int mode;
|
||||||
int n_ctx;
|
int n_ctx; // used to generate positions
|
||||||
float fs; // freq_scale
|
float fs; // freq_scale
|
||||||
float ef; // ext_factor
|
float ef; // ext_factor
|
||||||
float af; // attn_factor
|
float af; // attn_factor
|
||||||
|
@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
|
||||||
}
|
}
|
||||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
|
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
|
||||||
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
|
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
|
||||||
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
|
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
|
||||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
hp.n_rot, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
|
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
|
||||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
hp.n_rot, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -1465,7 +1465,7 @@ int main(int argc, const char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||||
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
|
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
|
||||||
|
@ -1505,7 +1505,7 @@ int main(int argc, const char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||||
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
||||||
|
|
|
@ -162,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
||||||
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
|
||||||
// 100, 101, 102, ..., 172
|
// 100, 101, 102, ..., 172
|
||||||
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
|
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
|
||||||
// -67, -67, -67, ..., -67
|
// -67, -67, -67, ..., -67
|
||||||
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
|
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
|
||||||
|
|
||||||
// 33, 34, 35, ..., 105
|
// 33, 34, 35, ..., 105
|
||||||
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);
|
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
|
||||||
|
|
||||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue