add long rope support in ggml cpu backend
This commit is contained in:
parent
9f871298b6
commit
c5569311a4
1 changed files with 11 additions and 1 deletions
12
ggml.c
12
ggml.c
|
@ -14370,6 +14370,15 @@ static void ggml_compute_forward_rope_f32(
|
|||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float* freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (dst->src[2] != NULL) {
|
||||
GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float*) dst->src[2]->data;
|
||||
}
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
|
@ -14446,10 +14455,11 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
sin_theta *= sin_sign;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue