YaRN : avoid NaN if unused betas are zero
This commit is contained in:
parent
826269adc5
commit
cf731d5648
3 changed files with 15 additions and 6 deletions
|
@ -4058,8 +4058,11 @@ static __device__ void rope_yarn(
|
|||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
|
|
|
@ -688,8 +688,11 @@ static void rope_yarn(
|
|||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
|
|
7
ggml.c
7
ggml.c
|
@ -12626,8 +12626,11 @@ static void rope_yarn(
|
|||
) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue