ggml : fixes (hopefully)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-29 14:50:22 +03:00
parent 9d5605f965
commit b822605abd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 28 additions and 51 deletions

View file

@ -61,7 +61,7 @@ static __global__ void rope(
template<typename T, bool has_pos, bool has_freq_facs> template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox( static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
) { ) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -85,15 +85,13 @@ static __global__ void rope_neox(
const int i = row*ncols + ib*n_dims + ic/2; const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0; const int p = has_pos ? pos[i2] : 0;
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0]; const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2]; const float x1 = x[i + n_dims/2];
@ -174,30 +172,29 @@ static void rope_neox_cuda(
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) { if (pos == nullptr) {
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, freq_factors
); );
} else { } else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, freq_factors
); );
} }
} else { } else {
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, freq_factors
); );
} else { } else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, freq_factors
); );
} }
} }

View file

@ -1767,13 +1767,13 @@ kernel void kernel_rope(
const int64_t p = pos[i2]; const int64_t p = pos[i2];
const float theta_0 = (float)p; const float theta_base = (float)p;
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/n_dims;
if (!is_neox) { if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -1789,18 +1789,14 @@ kernel void kernel_rope(
} else { } else {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) { if (ic < n_dims) {
const int64_t ib = 0; const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims` const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
const float cur_rot = inv_ndims*ic - ib;
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; const float theta = theta_base * pow(freq_base, inv_ndims*ic);
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
const int64_t i0 = ib*n_dims + ic/2;
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

42
ggml.c
View file

@ -14358,7 +14358,7 @@ static void ggml_compute_forward_rope_f32(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@ -14407,7 +14407,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign; const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@ -14442,29 +14442,22 @@ static void ggml_compute_forward_rope_f32(
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
} }
} else { } else {
// TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ic = 0; ic < ne0; ic += 2) { for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) { if (ic < n_dims) {
const int64_t ib = 0; const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims` const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn( rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
sin_theta *= sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -14543,7 +14536,7 @@ static void ggml_compute_forward_rope_f16(
int ir = 0; int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@ -14592,7 +14585,7 @@ static void ggml_compute_forward_rope_f16(
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign; const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@ -14623,29 +14616,22 @@ static void ggml_compute_forward_rope_f16(
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} }
} else { } else {
// TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ic = 0; ic < ne0; ic += 2) { for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) { if (ic < n_dims) {
const int64_t ib = 0; const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims` const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn( rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
sin_theta *= sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

View file

@ -2670,14 +2670,12 @@ void main() {
const uint i = row*p.ncols + ib*p.ndims + ic/2; const uint i = row*p.ncols + ib*p.ndims + ic/2;
const uint i2 = row/p.p_delta_rows; const uint i2 = row/p.p_delta_rows;
const float cur_rot = p.inv_ndims * ic - ib;
const int pos = data_b[i2]; const int pos = data_b[i2];
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f; const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor; const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta); rope_yarn(theta_base, ic, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]); const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.ndims/2]); const float x1 = float(data_a[i + p.ndims/2]);