From 241bb4571406049479d30e9e93de488158cad353 Mon Sep 17 00:00:00 2001 From: HimariO Date: Mon, 4 Nov 2024 22:11:13 +0800 Subject: [PATCH] fix rope op mode switching, out dated func args --- ggml/src/ggml.c | 60 ++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0cb53e13b..a38ee2e0f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3516,7 +3516,7 @@ static struct ggml_tensor * ggml_rope_impl( } bool is_node = false; - int sections[3] = {0, 0, 0}; + int sections[4] = {0, 0, 0, 0}; if (a->grad) { is_node = true; @@ -3524,14 +3524,14 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[14] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, §ions, sizeof(int) * 3); + memcpy(params + 11, §ions, sizeof(int) * 4); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -11238,7 +11238,7 @@ static void ggml_rope_cache_init( } static void ggml_mrope_cache_init( - float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects, + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py @@ -11406,19 +11406,21 @@ static void ggml_compute_forward_rope_f32( if (ir++ < ir0) continue; if (ir > ir1) break; - if (!is_neox) { + if (is_neox || is_mrope) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); const float x0 = src[0]; - const float x1 = src[1]; + const float x1 = src[n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; } } else if (is_vision){ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { @@ -11438,19 +11440,17 @@ static void ggml_compute_forward_rope_f32( } } else { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[1]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; } } @@ -11590,19 +11590,21 @@ static void ggml_compute_forward_rope_f16( if (ir++ < ir0) continue; if (ir > ir1) break; - if (!is_neox) { + if (is_neox || is_mrope) { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[1]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else if (is_vision){ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { @@ -11622,19 +11624,17 @@ static void ggml_compute_forward_rope_f16( } } else { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x1 = GGML_FP16_TO_FP32(src[1]); - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } }