ggml : fix YARN + add tests + add asserts (#7617)

* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-29 20:17:31 +03:00 committed by GitHub
parent cce3dcffc5
commit fb76ec31a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 167 additions and 104 deletions

View file

@ -1767,13 +1767,13 @@ kernel void kernel_rope(
const int64_t p = pos[i2];
const float theta_0 = (float)p;
const float theta_base = (float)p;
const float inv_ndims = -1.f/n_dims;
if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -1789,18 +1789,14 @@ kernel void kernel_rope(
} else {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);