fix rope op mode switching, out dated func args

This commit is contained in:
HimariO 2024-11-04 22:11:13 +08:00
parent f1fa60f84c
commit 241bb45714

View file

@ -3516,7 +3516,7 @@ static struct ggml_tensor * ggml_rope_impl(
} }
bool is_node = false; bool is_node = false;
int sections[3] = {0, 0, 0}; int sections[4] = {0, 0, 0, 0};
if (a->grad) { if (a->grad) {
is_node = true; is_node = true;
@ -3524,14 +3524,14 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[14] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &sections, sizeof(int) * 3); memcpy(params + 11, &sections, sizeof(int) * 4);
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
@ -11238,7 +11238,7 @@ static void ggml_rope_cache_init(
} }
static void ggml_mrope_cache_init( static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects, float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) { float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@ -11406,19 +11406,21 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
if (!is_neox) { if (is_neox || is_mrope) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0]; const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1]; const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[1]; const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta; dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} }
} else if (is_vision){ } else if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@ -11438,19 +11440,17 @@ static void ggml_compute_forward_rope_f32(
} }
} else { } else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0]; const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1]; const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0]; const float x0 = src[0];
const float x1 = src[n_dims/2]; const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; dst_data[1] = x0*sin_theta + x1*cos_theta;
} }
} }
@ -11590,19 +11590,21 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
if (!is_neox) { if (is_neox || is_mrope) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0]; const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1]; const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]); const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[1]); const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} }
} else if (is_vision){ } else if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@ -11622,19 +11624,17 @@ static void ggml_compute_forward_rope_f16(
} }
} else { } else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0]; const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1]; const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]); const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); const float x1 = GGML_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} }
} }