diff --git a/ggml.c b/ggml.c index ecbb0db80..169f98b23 100644 --- a/ggml.c +++ b/ggml.c @@ -14370,6 +14370,15 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float* freq_factors = NULL; + if (is_neox) { + if (dst->src[2] != NULL) { + GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); + freq_factors = (const float*) dst->src[2]->data; + } + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14446,10 +14455,11 @@ static void ggml_compute_forward_rope_f32( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign;