ggml : support freq_factors for f16 rope (CPU)
ggml-ci
This commit is contained in:
parent
ce89cd5573
commit
092549b110
1 changed files with 16 additions and 2 deletions
18
ggml.c
18
ggml.c
|
@ -14415,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
|
|||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
|
@ -14531,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: deduplicate f16/f32 code
|
||||
static void ggml_compute_forward_rope_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
|
@ -14538,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
|
|||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
|
@ -14590,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
|
|||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
|
@ -14662,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
|
|||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
sin_theta *= sin_sign;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue