Skip the YaRN ramp blend in rope_yarn() when ext_factor is exactly 0.

With ext_factor == 0.0f, theta is set directly to theta_interp and
rope_yarn_ramp() is not called. The same change is applied to the CUDA,
Metal and C implementations; ramp_mix is declared inside the new branch
in all three, since the old function-scope declaration is removed.

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 73a3399c5..c649e90a1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4058,8 +4058,11 @@ static __device__ void rope_yarn(
 ) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
-    float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-    float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    }
 
     // Get n-d magnitude scaling corrected for interpolation
     if (freq_scale > 1.0f)
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 6b0194d51..a1eb2d0d8 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -688,8 +688,11 @@ static void rope_yarn(
 ) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
-    float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-    float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    }
 
     // Get n-d magnitude scaling corrected for interpolation
     if (freq_scale > 1.0f)
diff --git a/ggml.c b/ggml.c
index 94a47faa1..85316a3a1 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12626,8 +12626,11 @@ static void rope_yarn(
 ) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
-    float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-    float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+    }
 
     // Get n-d magnitude scaling corrected for interpolation
     if (freq_scale > 1.0f)