ggml : unify rope norm/neox (CPU)

This commit is contained in:
Georgi Gerganov 2024-05-30 11:28:26 +03:00
parent 9422c5e34b
commit 2fd31fe188
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 81 additions and 143 deletions

212
ggml.c
View file

@ -6257,8 +6257,6 @@ static struct ggml_tensor * ggml_rope_impl(
float attn_factor, float attn_factor,
float beta_fast, float beta_fast,
float beta_slow, float beta_slow,
float xpos_base,
bool xpos_down,
bool inplace) { bool inplace) {
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
@ -6279,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; int32_t params[11] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
@ -6307,7 +6303,7 @@ struct ggml_tensor * ggml_rope(
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
); );
} }
@ -6319,7 +6315,7 @@ struct ggml_tensor * ggml_rope_inplace(
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
); );
} }
@ -6340,7 +6336,7 @@ struct ggml_tensor * ggml_rope_ext(
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false ext_factor, attn_factor, beta_fast, beta_slow, false
); );
} }
@ -6361,7 +6357,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ext_factor, attn_factor, beta_fast, beta_slow, true
); );
} }
@ -6381,7 +6377,7 @@ struct ggml_tensor * ggml_rope_custom(
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false ext_factor, attn_factor, beta_fast, beta_slow, false
); );
} }
@ -6401,20 +6397,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
float beta_slow) { float beta_slow) {
return ggml_rope_impl( return ggml_rope_impl(
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ext_factor, attn_factor, beta_fast, beta_slow, true
); );
} }
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_rope_back // ggml_rope_back
struct ggml_tensor * ggml_rope_back( struct ggml_tensor * ggml_rope_back(
@ -6431,9 +6417,7 @@ struct ggml_tensor * ggml_rope_back(
float ext_factor, float ext_factor,
float attn_factor, float attn_factor,
float beta_fast, float beta_fast,
float beta_slow, float beta_slow) {
float xpos_base,
bool xpos_down) {
GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]); GGML_ASSERT(a->ne[2] == b->ne[0]);
@ -6449,15 +6433,13 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; int32_t params[11] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_BACK; result->op = GGML_OP_ROPE_BACK;
@ -14282,13 +14264,15 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, fl
} }
static void ggml_rope_cache_init( static void ggml_rope_cache_init(
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float theta_base, float freq_scale, float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale float * cache, float sin_sign, float theta_scale
) { ) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta = theta_base; float theta = theta_base;
for (int64_t i0 = 0; i0 < ne0; i0 += 2) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
rope_yarn( rope_yarn(
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
); );
cache[i0 + 1] *= sin_sign; cache[i0 + 1] *= sin_sign;
@ -14321,10 +14305,6 @@ static void ggml_compute_forward_rope_f32(
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -14337,8 +14317,6 @@ static void ggml_compute_forward_rope_f32(
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
GGML_TENSOR_UNARY_OP_LOCALS GGML_TENSOR_UNARY_OP_LOCALS
@ -14396,18 +14374,16 @@ static void ggml_compute_forward_rope_f32(
const int64_t p = pos[i2]; const int64_t p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox if (!is_glm) { // TODO: cache sin/cos for glm
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
} }
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta_base = (float)p;
if (is_glm) { if (is_glm) {
theta_base = MIN(p, n_ctx - 2); float theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base); const float cos_theta = cosf(theta_base);
@ -14431,59 +14407,48 @@ static void ggml_compute_forward_rope_f32(
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
} }
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else { } else {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py const float theta_base = (float)p;
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t i0 = ic/2;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; if (!is_neox) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
float cos_theta, sin_theta; const float cos_theta = cache[i0 + 0];
rope_yarn( const float sin_theta = cache[i0 + 1];
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
theta_base *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[n_dims/2]; const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
} }
} }
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
} }
} }
} }
@ -14574,18 +14539,16 @@ static void ggml_compute_forward_rope_f16(
const int64_t p = pos[i2]; const int64_t p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox if (!is_glm) { // TODO: cache sin/cos for glm
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
} }
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
float theta_base = (float)p;
if (is_glm) { if (is_glm) {
theta_base = MIN(p, n_ctx - 2); float theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base); const float cos_theta = cosf(theta_base);
@ -14609,55 +14572,48 @@ static void ggml_compute_forward_rope_f16(
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
} }
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else { } else {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py const float theta_base = (float)p;
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t i0 = ic/2;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f; if (!is_neox) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
float cos_theta, sin_theta; const float cos_theta = cache[i0 + 0];
rope_yarn( const float sin_theta = cache[i0 + 1];
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
theta_base *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]); const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} else {
const int64_t i0 = ic;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
} }
} }
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
} }
} }
} }
@ -18361,7 +18317,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@ -18369,8 +18325,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_add_or_set(ctx, src0->grad = ggml_add_or_set(ctx,
src0->grad, src0->grad,
@ -18387,9 +18341,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ext_factor, ext_factor,
attn_factor, attn_factor,
beta_fast, beta_fast,
beta_slow, beta_slow),
xpos_base,
xpos_down),
zero_table); zero_table);
} }
} break; } break;
@ -18401,7 +18353,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@ -18409,8 +18361,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_add_or_set(ctx, src0->grad = ggml_add_or_set(ctx,
src0->grad, src0->grad,
@ -18428,8 +18378,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
attn_factor, attn_factor,
beta_fast, beta_fast,
beta_slow, beta_slow,
xpos_base,
xpos_down,
false), false),
zero_table); zero_table);
} }

12
ggml.h
View file

@ -1552,14 +1552,6 @@ extern "C" {
float beta_slow), float beta_slow),
"use ggml_rope_ext_inplace instead"); "use ggml_rope_ext_inplace instead");
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down);
// compute correction dims for YaRN RoPE scaling // compute correction dims for YaRN RoPE scaling
GGML_CALL void ggml_rope_yarn_corr_dims( GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
@ -1580,9 +1572,7 @@ extern "C" {
float ext_factor, float ext_factor,
float attn_factor, float attn_factor,
float beta_fast, float beta_fast,
float beta_slow, float beta_slow);
float xpos_base,
bool xpos_down);
// clamp // clamp
// in-place, returns view(a) // in-place, returns view(a)