diff --git a/ggml.c b/ggml.c index 3e721677a..009d5b398 100644 --- a/ggml.c +++ b/ggml.c @@ -11043,8 +11043,9 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; - // backward process uses derivative of cos and sin. - // derivative of cos is just cos, derivative of sin is -sin. + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. const float sin_sign = forward ? 1.0f : -1.0f; const int32_t * pos = (const int32_t *) src1->data; @@ -11199,8 +11200,9 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; - // backward process uses derivative of cos and sin. - // derivative of cos is just cos, derivative of sin is -sin. + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. const float sin_sign = forward ? 1.0f : -1.0f; const int32_t * pos = (const int32_t *) src1->data;