ggml : fix YARN + add tests + add asserts (#7617)
* tests : add rope tests ggml-ci * ggml : fixes (hopefully) ggml-ci * tests : add non-cont tests ggml-ci * cuda : add asserts for rope/norm + fix DS2 ggml-ci * ggml : assert contiguousness * tests : reduce RoPE tests ggml-ci
This commit is contained in:
parent
cce3dcffc5
commit
fb76ec31a9
12 changed files with 167 additions and 104 deletions
|
@ -1767,13 +1767,13 @@ kernel void kernel_rope(
|
|||
|
||||
const int64_t p = pos[i2];
|
||||
|
||||
const float theta_0 = (float)p;
|
||||
const float theta_base = (float)p;
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
@ -1789,18 +1789,14 @@ kernel void kernel_rope(
|
|||
} else {
|
||||
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t ib = 0;
|
||||
const int64_t i0 = ic/2;
|
||||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
const float cur_rot = inv_ndims*ic - ib;
|
||||
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
||||
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue