ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU) * ggml : fix compile warning * ggml : remove GLM rope mode ggml-ci * metal : better rope implementation ggml-ci * cuda : better rope implementation ggml-ci * naming : n_orig_ctx -> n_ctx_orig ggml-ci * dev : add reminders to update backends ggml-ci * vulkan : fix ggml_rope_ext() usage * cuda : fix array size + indents ggml-ci
This commit is contained in:
parent
9973e81c5c
commit
2b3389677a
19 changed files with 485 additions and 732 deletions
334
ggml.c
334
ggml.c
|
@ -6250,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||
|
||||
|
@ -6280,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
|
@ -6305,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
int mode) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -6317,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
int mode) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -6331,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
|
|||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -6340,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -6352,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
|||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -6361,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -6372,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
|
|||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -6381,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -6392,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -6401,21 +6390,11 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
|
||||
struct ggml_tensor * ggml_rope_back(
|
||||
|
@ -6425,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
|
|||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float xpos_base,
|
||||
bool xpos_down) {
|
||||
float beta_slow) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
@ -6450,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
|
|||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE_BACK;
|
||||
|
@ -14227,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta
|
||||
) {
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
|
@ -14245,18 +14218,19 @@ static void rope_yarn(
|
|||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
}
|
||||
|
||||
static void ggml_rope_cache_init(
|
||||
float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale
|
||||
) {
|
||||
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
||||
rope_yarn(
|
||||
theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||
);
|
||||
cache[i0 + 1] *= sin_sign;
|
||||
|
||||
|
@ -14265,11 +14239,11 @@ static void ggml_rope_cache_init(
|
|||
}
|
||||
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
|
||||
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
|
||||
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
|
||||
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
|
||||
dims[0] = MAX(0, start);
|
||||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
|
@ -14289,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
// these two only relevant for xPos RoPE:
|
||||
float xpos_base;
|
||||
bool xpos_down;
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
|
@ -14305,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
|
@ -14336,20 +14304,15 @@ static void ggml_compute_forward_rope_f32(
|
|||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
|
@ -14364,95 +14327,51 @@ static void ggml_compute_forward_rope_f32(
|
|||
const int64_t p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
||||
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
const float x2 = src[n_dims];
|
||||
const float x3 = src[n_dims/2*3];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
|
||||
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
|
||||
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t i0 = ic/2;
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
sin_theta *= sin_sign;
|
||||
theta_base *= theta_scale;
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
const int64_t i0 = ic;
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14477,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
|
|||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
|
@ -14514,20 +14433,15 @@ static void ggml_compute_forward_rope_f16(
|
|||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
|
@ -14542,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
|
|||
const int64_t p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
||||
float theta_base = (float)p;
|
||||
|
||||
if (is_glm) {
|
||||
theta_base = MIN(p, n_ctx - 2);
|
||||
float block_theta = MAX(p - (n_ctx - 2), 0);
|
||||
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
||||
const float cos_theta = cosf(theta_base);
|
||||
const float sin_theta = sinf(theta_base) * sin_sign;
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
const float sin_block_theta = sinf(block_theta) * sin_sign;
|
||||
|
||||
theta_base *= theta_scale;
|
||||
block_theta *= theta_scale;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
|
||||
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
|
||||
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
|
||||
}
|
||||
} else if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
|
@ -14592,41 +14477,30 @@ static void ggml_compute_forward_rope_f16(
|
|||
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
} else {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
for (int64_t ic = 0; ic < ne0; ic += 2) {
|
||||
if (ic < n_dims) {
|
||||
const int64_t i0 = ic/2;
|
||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
|
||||
const float cos_theta = cache[i0 + 0];
|
||||
const float sin_theta = cache[i0 + 1];
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
sin_theta *= sin_sign;
|
||||
theta_base *= theta_scale;
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
} else {
|
||||
const int64_t i0 = ic;
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -18327,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
|
||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
||||
|
@ -18337,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
||||
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
|
@ -18348,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
n_orig_ctx,
|
||||
n_ctx_orig,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
xpos_base,
|
||||
xpos_down),
|
||||
beta_slow),
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
|
@ -18367,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
|
||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
||||
|
@ -18377,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
||||
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
|
@ -18388,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
n_orig_ctx,
|
||||
n_ctx_orig,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
ext_factor,
|
||||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
xpos_base,
|
||||
xpos_down,
|
||||
false),
|
||||
zero_table);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue