diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 50f2cf415..34225c2be 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -61,7 +61,7 @@ static __global__ void rope( template static __global__ void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -85,15 +85,13 @@ static __global__ void rope_neox( const int i = row*ncols + ib*n_dims + ic/2; const int i2 = row/p_delta_rows; - float cur_rot = inv_ndims * ic - ib; - const int p = has_pos ? pos[i2] : 0; const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; + const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor; float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + n_dims/2]; @@ -174,30 +172,29 @@ static void rope_neox_cuda( const dim3 block_nums(nrows, num_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { if (freq_factors == nullptr) { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims, freq_factors + theta_scale, freq_factors ); } else { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims, freq_factors + theta_scale, freq_factors ); } } else { if (freq_factors == nullptr) { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims, freq_factors + theta_scale, freq_factors ); } else { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims, freq_factors + theta_scale, freq_factors ); } } diff --git a/ggml-metal.metal b/ggml-metal.metal index b16f2b7e0..0cb85e1a5 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1767,13 +1767,13 @@ kernel void kernel_rope( const int64_t p = pos[i2]; - const float theta_0 = (float)p; + const float theta_base = (float)p; const float inv_ndims = -1.f/n_dims; if (!is_neox) { for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + const float theta = theta_base * pow(freq_base, inv_ndims*i0); - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -1789,18 +1789,14 @@ kernel void kernel_rope( } else { for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { if (ic < n_dims) { - const int64_t ib = 0; + const int64_t i0 = ic/2; - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; - const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + const float freq_factor = src2 != src0 ? src2[i0] : 1.0f; - const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; + const float theta = theta_base * pow(freq_base, inv_ndims*ic); float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const int64_t i0 = ib*n_dims + ic/2; + rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); diff --git a/ggml.c b/ggml.c index e6e2397b7..7464ae397 100644 --- a/ggml.c +++ b/ggml.c @@ -14358,7 +14358,7 @@ static void ggml_compute_forward_rope_f32( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); @@ -14407,7 +14407,7 @@ static void ggml_compute_forward_rope_f32( const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta) * sin_sign; - theta_base *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -14442,29 +14442,22 @@ static void ggml_compute_forward_rope_f32( dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py for (int64_t ic = 0; ic < ne0; ic += 2) { if (ic < n_dims) { - const int64_t ib = 0; + const int64_t i0 = ic/2; - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; - float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; + const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta ); - sin_theta *= sin_sign; + sin_theta *= sin_sign; theta_base *= theta_scale; - const int64_t i0 = ib*n_dims + ic/2; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -14543,7 +14536,7 @@ static void ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); @@ -14592,7 +14585,7 @@ static void ggml_compute_forward_rope_f16( const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta) * sin_sign; - theta_base *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -14623,29 +14616,22 @@ static void ggml_compute_forward_rope_f16( dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py for (int64_t ic = 0; ic < ne0; ic += 2) { if (ic < n_dims) { - const int64_t ib = 0; + const int64_t i0 = ic/2; - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; - float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f; + const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta ); - sin_theta *= sin_sign; + sin_theta *= sin_sign; theta_base *= theta_scale; - const int64_t i0 = ib*n_dims + ic/2; - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index a8f7373df..7c85ca7ba 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -2670,14 +2670,12 @@ void main() { const uint i = row*p.ncols + ib*p.ndims + ic/2; const uint i2 = row/p.p_delta_rows; - const float cur_rot = p.inv_ndims * ic - ib; - const int pos = data_b[i2]; const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f; const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor; float cos_theta, sin_theta; - rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta); + rope_yarn(theta_base, ic, cos_theta, sin_theta); const float x0 = float(data_a[i + 0]); const float x1 = float(data_a[i + p.ndims/2]);