YaRN : correction to GPT-NeoX implementation

ggml-ci
This commit is contained in:
Jared Van Bortel 2023-11-15 17:07:57 -05:00
parent a6fc554e26
commit f824902623
3 changed files with 7 additions and 12 deletions

View file

@ -1277,10 +1277,9 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const int64_t cur_rot = ib * n_dims + ic;
const float theta = theta_0 * pow(freq_base, cur_rot);
const float theta = theta_0 * pow(freq_base, inv_ndims*cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);