From 99bb26078f6baec8bcb995bf09c61dee8b0d4c75 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 10:41:35 +0300 Subject: [PATCH] metal : implement RoPE (mode = 2) + avoid ggml_repeat --- ggml-metal.metal | 20 +++++++++++++++++++- llama.cpp | 18 ++++++------------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ce3541f4b..53604a250 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -571,7 +571,25 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: implement + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } } } diff --git a/llama.cpp b/llama.cpp index e801eb333..326852429 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2445,19 +2445,15 @@ static struct ggml_cgraph * llm_build_falcon( attn_norm = ggml_norm(ctx0, inpL, norm_eps); attn_norm = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm), - attn_norm), - ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm)); + ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), + model.layers[il].attn_norm_b); if (model.layers[il].attn_norm_2) { // Falcon-40B cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur), - cur), - ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur)); + ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), + model.layers[il].attn_norm_2_b); } else { // Falcon 7B cur = attn_norm; } @@ -2595,10 +2591,8 @@ static struct ggml_cgraph * llm_build_falcon( cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.output_norm, cur), - cur), - ggml_repeat(ctx0, model.output_norm_b, cur)); + ggml_mul(ctx0, cur, model.output_norm), + model.output_norm_b); ggml_set_name(cur, "result_norm"); }