YaRN : avoid NaN if unused betas are zero

This commit is contained in:
Cebtenzzre 2023-09-05 14:14:05 -04:00
parent 826269adc5
commit cf731d5648
3 changed files with 15 additions and 6 deletions

View file

@ -4058,8 +4058,11 @@ static __device__ void rope_yarn(
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)

View file

@ -688,8 +688,11 @@ static void rope_yarn(
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
float theta = theta_interp;
if (ext_factor != 0.0f) {
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)

7
ggml.c
View file

@ -12626,8 +12626,11 @@ static void rope_yarn(
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)