From fe788c45c8d5a2c6e6c2f0ec04978419555c65b2 Mon Sep 17 00:00:00 2001
From: Cebtenzzre
Date: Thu, 21 Sep 2023 00:01:48 -0400
Subject: [PATCH] don't hardcode max_pos_emb

---
 .../train-text-from-scratch.cpp |   2 +-
 ggml-cuda.cu                    |  27 ++---
 ggml-metal.m                    |  19 ++--
 ggml-metal.metal                |  17 +--
 ggml.c                          | 100 ++++++++++--------
 ggml.h                          |   7 +-
 llama.cpp                       |  67 +++++++-----
 7 files changed, 131 insertions(+), 108 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 8a5ad82bf..36415398e 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -687,7 +687,7 @@ struct ggml_tensor * llama_build_train_graphs(
         const int rope_mode = 0;
 
         return ggml_rope_custom(
-            ctx, t, n_past, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+            ctx, t, n_past, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 394fd81fa..bd788ce4d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4386,7 +4386,7 @@ static __device__ void rope_yarn(
 
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(
-    float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
+    const float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
     float p0, int p_delta_rows, rope_corr_dims corr_dims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -5396,7 +5396,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 }
 
 static void rope_f32_cuda(
-    float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
+    const float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
     float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
@@ -6109,19 +6109,20 @@ inline void ggml_cuda_op_rope(
     const int64_t ne01 = src0->ne[1];
 
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    memcpy(&freq_base,   (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float p0 = (mode & 1) == 0 ? n_past : 0;
@@ -6137,7 +6138,7 @@ inline void ggml_cuda_op_rope(
         rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
     } else {
         rope_corr_dims corr_dims;
-        ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
+        ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 
         rope_f32_cuda(
             src0_dd, dst_dd, ne00, nrows, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01, corr_dims,
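Note on the op_params layout: after this patch every backend decodes the same flat
int32 array, with the new n_orig_ctx in slot 4 and the six YaRN floats shifted up by
one slot. A minimal sketch of the decoding, assuming the slot assignments above (the
helper name is illustrative, not part of the patch):

    /* Slot layout after this patch:
     *   [0] n_past   [1] n_dims    [2] mode       [3] n_ctx      [4] n_orig_ctx
     *   [5] freq_base ... [10] beta_slow   [11] xpos_base   [12] xpos_down  */
    #include <stdint.h>
    #include <string.h>

    static void decode_rope_params(const int32_t * p, int * n_orig_ctx, float * freq_base) {
        *n_orig_ctx = p[4];                       // new in this patch
        memcpy(freq_base, p + 5, sizeof(float));  // was slot 4 before
    }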
diff --git a/ggml-metal.m b/ggml-metal.m
index 06d97695b..c0607e844 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1176,17 +1176,18 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ROPE:
                         {
-                            const int n_past = ((int32_t *) dst->op_params)[0];
-                            const int n_dims = ((int32_t *) dst->op_params)[1];
-                            const int mode   = ((int32_t *) dst->op_params)[2];
+                            const int n_past     = ((int32_t *) dst->op_params)[0];
+                            const int n_dims     = ((int32_t *) dst->op_params)[1];
+                            const int mode       = ((int32_t *) dst->op_params)[2];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
 
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-                            memcpy(&freq_base,   (int32_t *) dst->op_params + 4, sizeof(float));
-                            memcpy(&freq_scale,  (int32_t *) dst->op_params + 5, sizeof(float));
-                            memcpy(&ext_factor,  (int32_t *) dst->op_params + 6, sizeof(float));
-                            memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-                            memcpy(&beta_fast,   (int32_t *) dst->op_params + 8, sizeof(float));
-                            memcpy(&beta_slow,   (int32_t *) dst->op_params + 9, sizeof(float));
+                            memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+                            memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+                            memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+                            memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+                            memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+                            memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c5e0ee8a0..ddf81fb7b 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -832,18 +832,18 @@ static void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-constant float max_pos_emb = 2048;
-
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(const int n_dims, const float n_rot, const float base) {
-    return n_dims * log(max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
-static void rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
     // start and end correction dims
-    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, beta_slow, freq_base)));
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }
 
 kernel void kernel_rope(
@@ -868,6 +868,7 @@ kernel void kernel_rope(
         constant         int & n_past,
         constant         int & n_dims,
         constant         int & mode,
+        constant         int & n_orig_ctx,
         constant       float & freq_base,
         constant       float & freq_scale,
         constant       float & ext_factor,
@@ -884,7 +885,7 @@ kernel void kernel_rope(
     const bool is_neox = mode & 2;
 
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const int64_t p = (mode & 1) == 0 ? n_past + i2 : i2;
 
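For reference, the factor being parameterized here comes from counting rotations:
dimension d of a base-b RoPE has wavelength 2*pi*b^(2d/n_dims), so over the original
training context it completes n_rot full turns. Solving for d gives the expression in
rope_yarn_corr_factor. A sketch of the algebra in LaTeX (my reconstruction of what the
in-tree comment is driving at, with n_orig_ctx replacing the old hard-coded
max_pos_emb = 2048):

    n_{\mathrm{rot}} = \frac{n_{\mathrm{orig\_ctx}}}{2\pi\, b^{2d/n_{\mathrm{dims}}}}
    \quad\Longrightarrow\quad
    d = \frac{n_{\mathrm{dims}} \ln\!\bigl(n_{\mathrm{orig\_ctx}} / (2\pi\, n_{\mathrm{rot}})\bigr)}{2 \ln b}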
diff --git a/ggml.c b/ggml.c
index 53137924d..56b9fdd29 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6973,6 +6973,7 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -6991,15 +6992,15 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[12] = { n_past, n_dims, mode, n_ctx };
-    memcpy(params +  4, &freq_base,   sizeof(float));
-    memcpy(params +  5, &freq_scale,  sizeof(float));
-    memcpy(params +  6, &ext_factor,  sizeof(float));
-    memcpy(params +  7, &attn_factor, sizeof(float));
-    memcpy(params +  8, &beta_fast,   sizeof(float));
-    memcpy(params +  9, &beta_slow,   sizeof(float));
-    memcpy(params + 10, &xpos_base,   sizeof(float));
-    memcpy(params + 11, &xpos_down,   sizeof(bool));
+    int32_t params[13] = { n_past, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,   sizeof(float));
+    memcpy(params +  6, &freq_scale,  sizeof(float));
+    memcpy(params +  7, &ext_factor,  sizeof(float));
+    memcpy(params +  8, &attn_factor, sizeof(float));
+    memcpy(params +  9, &beta_fast,   sizeof(float));
+    memcpy(params + 10, &beta_slow,   sizeof(float));
+    memcpy(params + 11, &xpos_base,   sizeof(float));
+    memcpy(params + 12, &xpos_down,   sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_ROPE;
@@ -7017,7 +7018,7 @@ struct ggml_tensor * ggml_rope(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -7029,7 +7030,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
     );
 }
 
@@ -7040,6 +7041,7 @@ struct ggml_tensor * ggml_rope_custom(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -7047,8 +7049,8 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
-        false, false
+        ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
 
@@ -7059,6 +7061,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
         float                 ext_factor,
@@ -7066,8 +7069,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
-        false, true
+        ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
@@ -7078,7 +7081,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
         int                   n_dims,
         float                 base,
         bool                  down) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 
 // ggml_rope_back
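Call sites that used to pass freq_base right after n_ctx must now insert n_orig_ctx
in front of it. The plain ggml_rope/ggml_rope_inplace wrappers pass 0 there, which
appears to be a harmless placeholder: they also pass ext_factor = 0.0f, and the
correction range only matters when the YaRN ramp is active. A hypothetical
YaRN-enabled call, given an existing ggml_context * ctx and input tensor x (all
values illustrative, not from this patch):

    struct ggml_tensor * cur = ggml_rope_custom(
        ctx, x,
        /*n_past*/      0,
        /*n_dims*/      128,
        /*mode*/        0,
        /*n_ctx*/       4096,
        /*n_orig_ctx*/  2048,      // training context, no longer hard-coded
        /*freq_base*/   10000.0f,
        /*freq_scale*/  0.5f,
        /*ext_factor*/  1.0f,      // nonzero enables the YaRN ramp
        /*attn_factor*/ 1.0f,
        /*beta_fast*/   32.0f,
        /*beta_slow*/   1.0f);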
@@ -12675,15 +12678,16 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(const int n_dims, const float n_rot, const float base) {
-    static const float max_pos_emb = 2048;
-    return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
 
-void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
+void ggml_rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
     // start and end correction dims
-    dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, beta_fast, freq_base)));
-    dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, beta_slow, freq_base)));
+    dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }
 
 static void ggml_compute_forward_rope_f32(
@@ -12701,18 +12705,20 @@ static void ggml_compute_forward_rope_f32(
     float xpos_base;
     bool  xpos_down;
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,   (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&xpos_base,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&xpos_down,   (int32_t *) dst->op_params + 11, sizeof(bool));
+    const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
+    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));
 
     assert(n_past >= 0);
 
@@ -12743,7 +12749,7 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
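As a quick sanity check of the new signature, the correction range can be computed
standalone. A sketch mirroring ggml_rope_yarn_corr_dims with illustrative LLaMA-style
numbers (n_dims = 128, n_orig_ctx = 2048, base = 10000, and YaRN-default
beta_fast = 32, beta_slow = 1; all of these are assumptions, not values from the patch):

    #include <math.h>
    #include <stdio.h>

    // mirrors ggml_rope_yarn_corr_dim() above
    static float corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
        return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
    }

    int main(void) {
        const int   n_dims = 128, n_orig_ctx = 2048;
        const float base = 10000.0f, beta_fast = 32.0f, beta_slow = 1.0f;
        const float start = fmaxf(0.0f,          floorf(corr_dim(n_dims, n_orig_ctx, beta_fast, base)));
        const float end   = fminf(n_dims - 1.0f, ceilf (corr_dim(n_dims, n_orig_ctx, beta_slow, base)));
        printf("corr dims: [%g, %g]\n", start, end);  // roughly [16, 41]
        return 0;
    }

Dimensions below the start of that range keep their extrapolated frequencies, those
past the end are fully interpolated, and the YaRN ramp blends in between.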
@@ -12844,16 +12850,17 @@ static void ggml_compute_forward_rope_f16(
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,   (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 9, sizeof(float));
+    const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12884,7 +12891,7 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -16641,6 +16648,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         n_past,
                         n_dims,
                         mode,
                         n_ctx,
+                        0,
                         freq_base,
                         freq_scale,
diff --git a/ggml.h b/ggml.h
index 5078fb7b5..26f7cf024 100644
--- a/ggml.h
+++ b/ggml.h
@@ -219,7 +219,7 @@
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          64
-#define GGML_MAX_OP_PARAMS     48
+#define GGML_MAX_OP_PARAMS     64
 #define GGML_DEFAULT_N_THREADS 4
 
 #if UINTPTR_MAX == 0xFFFFFFFF
@@ -1248,6 +1248,7 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1263,6 +1264,7 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1271,7 +1273,8 @@ extern "C" {
             float                 beta_slow);
 
     // compute correction dims for YaRN RoPE scaling
-    void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]);
+    void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // xPos RoPE, in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
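The GGML_MAX_OP_PARAMS bump is required, not cosmetic: op params are a byte budget,
and the rope array grows from 12 to 13 int32 values, i.e. from 48 to 52 bytes, which
no longer fits the old 48-byte cap that ggml_set_op_params() checks against. A
compile-time sketch of the same arithmetic (the static assertion is illustrative,
not part of the patch):

    #include <stdint.h>

    #define GGML_MAX_OP_PARAMS 64  // raised from 48 by this patch

    // 13 * sizeof(int32_t) = 52 bytes; this would have failed at 48
    _Static_assert(13 * sizeof(int32_t) <= GGML_MAX_OP_PARAMS,
                   "rope op_params must fit in GGML_MAX_OP_PARAMS");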
diff --git a/llama.cpp b/llama.cpp
index 56c511b59..7184c376c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2471,13 +2471,14 @@ static struct ggml_cgraph * llm_build_llama(
 
     GGML_ASSERT(!!kv_self.ctx);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int32_t n_embd      = hparams.n_embd;
+    const int32_t n_layer     = hparams.n_layer;
+    const int32_t n_ctx       = hparams.n_ctx;
+    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_head      = hparams.n_head;
+    const int32_t n_head_kv   = hparams.n_head_kv;
+    const int32_t n_embd_head = hparams.n_embd_head();
+    const int32_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -2599,14 +2600,18 @@ static struct ggml_cgraph * llm_build_llama(
                 ggml_set_name(tmpq, "tmpq");
 
                 struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
-                    freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
+                    n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 offload_func_kq(Kcur);
                 ggml_set_name(Kcur, "Kcur");
 
                 struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
-                    freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
+                    n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 offload_func_kq(Qcur);
                 ggml_set_name(Qcur, "Qcur");
 
@@ -2811,13 +2816,14 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     GGML_ASSERT(!!kv_self.ctx);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int32_t n_embd      = hparams.n_embd;
+    const int32_t n_layer     = hparams.n_layer;
+    const int32_t n_ctx       = hparams.n_ctx;
+    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_head      = hparams.n_head;
+    const int32_t n_head_kv   = hparams.n_head_kv;
+    const int32_t n_embd_head = hparams.n_embd_head();
+    const int32_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -2945,12 +2951,14 @@ static struct ggml_cgraph * llm_build_baichaun(
                 switch (model.type) {
                     case MODEL_7B:
                         Kcur = ggml_rope_custom_inplace(
                             ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
-                            n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                            n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         Qcur = ggml_rope_custom_inplace(
                             ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
-                            n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                            n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
                     case MODEL_13B:
@@ -3183,13 +3191,14 @@ static struct ggml_cgraph * llm_build_falcon(
 
     GGML_ASSERT(!!kv_self.ctx);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int32_t n_embd      = hparams.n_embd;
+    const int32_t n_layer     = hparams.n_layer;
+    const int32_t n_ctx       = hparams.n_ctx;
+    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_head      = hparams.n_head;
+    const int32_t n_head_kv   = hparams.n_head_kv;
+    const int32_t n_embd_head = hparams.n_embd_head();
+    const int32_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -3351,12 +3360,12 @@ static struct ggml_cgraph * llm_build_falcon(
 
                 // using mode = 2 for neox mode
                 struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-                    ctx0, tmpq, n_past, n_embd_head, 2, 0,
+                    ctx0, tmpq, n_past, n_embd_head, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 offload_func_kq(Qcur);
                 struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-                    ctx0, tmpk, n_past, n_embd_head, 2, 0,
+                    ctx0, tmpk, n_past, n_embd_head, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 offload_func_kq(Kcur);
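Taken together, the graph builders now thread hparams.n_yarn_orig_ctx into every
rope call instead of relying on the kernel-side constant. A condensed sketch of the
shared pattern (abbreviated from the llm_build_falcon hunk above, not verbatim):

    // mode is a bit field, per the hunks above: bit 0 makes positions start at 0
    // instead of n_past, bit 1 selects the GPT-NeoX layout, bit 2 the GLM variant.
    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;  // original training context

    struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
        ctx0, tmpq, n_past, n_embd_head, /*mode=*/2, /*n_ctx=*/0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);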