YaRN : correction to GPT-NeoX implementation

ggml-ci
This commit is contained in:
Jared Van Bortel 2023-11-15 17:07:57 -05:00
parent a6fc554e26
commit f824902623
3 changed files with 7 additions and 12 deletions

View file

@ -4596,14 +4596,12 @@ static __global__ void rope_neox(
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
const float cur_rot = -float(col)/ncols;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, cur_rot);
const float theta_base = p*powf(freq_base, -float(col)/ncols);
// rotation amount is `ib * ncols + col`, but ib is assumed to be zero
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + ncols/2];