diff --git a/ggml.c b/ggml.c index 11d71c07a..d316e3d31 100644 --- a/ggml.c +++ b/ggml.c @@ -14415,7 +14415,7 @@ static void ggml_compute_forward_rope_f32( freq_factors = (const float *) src2->data; } } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); } // backward process uses inverse rotation by cos and sin. @@ -14531,6 +14531,7 @@ static void ggml_compute_forward_rope_f32( } } +// TODO: deduplicate f16/f32 code static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -14538,6 +14539,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14590,6 +14592,17 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float * freq_factors = NULL; + if (is_neox) { + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14662,10 +14675,11 @@ static void ggml_compute_forward_rope_f16( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign;