From a30ae2095c91adbb9e3d626c1bae234fcb43e669 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Mon, 4 Sep 2023 20:08:17 -0400 Subject: [PATCH] implement new YaRN algorithm --- examples/common.cpp | 68 +++++++++------- examples/common.h | 4 +- examples/server/server.cpp | 36 ++++++--- ggml-cuda.cu | 79 +++++++++---------- ggml-metal.m | 22 +++--- ggml-metal.metal | 70 +++++++---------- ggml.c | 156 +++++++++++++++++-------------------- ggml.h | 16 ++-- llama.cpp | 111 +++++++++++++++----------- llama.h | 10 ++- 10 files changed, 303 insertions(+), 269 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 957022d0c..ef08c403c 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -194,18 +194,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.rope_freq_scale = std::stof(argv[i]); - } else if (arg == "--rope-ntk-factor") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_ntk_factor = std::stof(argv[i]); } else if (arg == "--rope-ext-factor") { if (++i >= argc) { invalid_param = true; break; } params.rope_ext_factor = std::stof(argv[i]); + } else if (arg == "--rope-attn-factor") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_attn_factor = std::stof(argv[i]); + } else if (arg == "--rope-beta-fast") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_beta_fast = std::stof(argv[i]); + } else if (arg == "--rope-beta-slow") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_beta_slow = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { @@ -578,8 +590,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); - fprintf(stdout, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor); fprintf(stdout, " --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor); + fprintf(stdout, " --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor); + fprintf(stdout, " --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast); + fprintf(stdout, " --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow); fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stdout, " --no-penalize-nl do not penalize newline token\n"); fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); @@ -654,25 +668,27 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_batch = params.n_batch; - lparams.n_gqa = params.n_gqa; - lparams.rms_norm_eps = params.rms_norm_eps; - lparams.n_gpu_layers = params.n_gpu_layers; - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.mul_mat_q = params.mul_mat_q; - lparams.seed = params.seed; - 
lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; - lparams.rope_freq_base = params.rope_freq_base; - lparams.rope_freq_scale = params.rope_freq_scale; - lparams.rope_ntk_factor = params.rope_ntk_factor; - lparams.rope_ext_factor = params.rope_ext_factor; + lparams.n_ctx = params.n_ctx; + lparams.n_batch = params.n_batch; + lparams.n_gqa = params.n_gqa; + lparams.rms_norm_eps = params.rms_norm_eps; + lparams.n_gpu_layers = params.n_gpu_layers; + lparams.main_gpu = params.main_gpu; + lparams.tensor_split = params.tensor_split; + lparams.low_vram = params.low_vram; + lparams.mul_mat_q = params.mul_mat_q; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; + lparams.rope_freq_base = params.rope_freq_base; + lparams.rope_freq_scale = params.rope_freq_scale; + lparams.rope_ext_factor = params.rope_ext_factor; + lparams.rope_attn_factor = params.rope_attn_factor; + lparams.rope_beta_fast = params.rope_beta_fast; + lparams.rope_beta_slow = params.rope_beta_slow; return lparams; } diff --git a/examples/common.h b/examples/common.h index 677676ad1..8410b38a5 100644 --- a/examples/common.h +++ b/examples/common.h @@ -32,8 +32,10 @@ struct gpt_params { float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_scale = 1.0f; // RoPE frequency scaling factor - float rope_ntk_factor = 0.0f; // RoPE NTK mix factor float rope_ext_factor = 0.0f; // RoPE extrapolation mix factor + float rope_attn_factor = 1.0f; // RoPE magnitude scaling factor + float rope_beta_fast = 32.0f; // RoPE low correction dim + float rope_beta_slow = 1.0f; // RoPE high correction dim // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 49d2dd050..9721f2692 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -612,8 +612,10 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! 
use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); - fprintf(stdout, " --rope-ntk-factor N RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor); fprintf(stdout, " --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor); + fprintf(stdout, " --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor); + fprintf(stdout, " --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast); + fprintf(stdout, " --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -766,14 +768,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.rope_freq_scale = std::stof(argv[i]); } - else if (arg == "--rope-ntk-factor") - { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_ntk_factor = std::stof(argv[i]); - } else if (arg == "--rope-ext-factor") { if (++i >= argc) { @@ -782,6 +776,30 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.rope_ext_factor = std::stof(argv[i]); } + else if (arg == "--rope-attn-factor") + { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_attn_factor = std::stof(argv[i]); + } + else if (arg == "--rope-beta-fast") + { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_beta_fast = std::stof(argv[i]); + } + else if (arg == "--rope-beta-slow") + { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_beta_slow = std::stof(argv[i]); + } else if (arg == "--memory-f32" || arg == "--memory_f32") { params.memory_f16 = false; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 91a6edca6..dedb87efd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3558,34 +3558,31 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -static __device__ float rope_ntkv2_ramp(const float low, const float high, const int i0) { +static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / min(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); } -struct rope_corr_factors { +struct rope_corr_dims { float v[4]; }; -// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
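// [Illustrative aside, not part of the patch] Minimal standalone C sketch of the
// YaRN mixing rule that the CUDA, Metal and CPU kernels in this patch implement:
// the rotation angle is blended between the interpolated (freq_scale-scaled) and
// extrapolated values with a per-dimension ramp, then scaled by mscale. The names
// yarn_ramp/yarn_theta and the corr_dims/inputs in main() are assumptions for the
// example; the fmaxf denominator guard here only prevents division by zero.
#include <math.h>
#include <stdio.h>

static float yarn_ramp(float low, float high, int i0) {
    // linear ramp over the correction-dim range [low, high], clamped to [0, 1]
    float y = ((float)(i0 / 2) - low) / fmaxf(0.001f, high - low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

static void yarn_theta(
    float theta_extrap, float freq_scale, const float corr_dims[2], int i0,
    float ext_factor, float mscale, float * cos_theta, float * sin_theta
) {
    float theta_interp = freq_scale * theta_extrap;  // plain position interpolation
    float ramp_mix     = yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
    float theta        = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
    *cos_theta = cosf(theta) * mscale;
    *sin_theta = sinf(theta) * mscale;
}

int main(void) {
    const float corr_dims[2] = { 16.0f, 41.0f };  // example correction dims
    float c, s;
    yarn_theta(/*theta_extrap=*/1.0f, /*freq_scale=*/0.25f, corr_dims,
               /*i0=*/8, /*ext_factor=*/1.0f, /*mscale=*/1.0f, &c, &s);
    // i0/2 < corr_dims[0], so ramp_mix == 1 and the angle stays extrapolated (un-scaled)
    printf("cos = %f, sin = %f\n", c, s);
    return 0;
}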
-static __device__ float rope_ntkv2( - const float theta_base, - const float theta_linear, - const float theta_ntk, - const rope_corr_factors corr_factors, - const int64_t i0, - const float ntk_factor, - const float ext_factor) { - float ramp_mix; - float theta; +static __device__ void rope_yarn( + float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - ramp_mix = rope_ntkv2_ramp(corr_factors.v[0], corr_factors.v[1], i0) * ntk_factor; - theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - - ramp_mix = rope_ntkv2_ramp(corr_factors.v[2], corr_factors.v[3], i0) * ext_factor; - theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; - return theta; + // Get n-d magnitude scaling corrected for interpolation + if (freq_scale > 1.0f) + mscale *= 1.0f + 0.1f * logf(freq_scale); + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; } // rope == RoPE == rotary positional embedding @@ -3594,13 +3591,11 @@ static __global__ void rope_f32( float * dst, const int ncols, const float freq_scale, - const float ntk_factor, const float ext_factor, const float theta_scale, - const float theta_ntk_scale, const float p0, const int p_delta_rows, - const rope_corr_factors corr_factors) { + const rope_corr_dims corr_dims) { const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); if (col >= ncols) { @@ -3612,11 +3607,9 @@ static __global__ void rope_f32( const float p = p0 + row / p_delta_rows; const float theta_base = p*powf(theta_scale, col/2); - const float theta_linear = freq_scale * theta_base; - const float theta_ntk = p*powf(theta_ntk_scale, col/2); - const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor, ext_factor); - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); + + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + 1]; @@ -4284,20 +4277,19 @@ static void rope_f32_cuda( const int ncols, const int nrows, const float freq_scale, - const float ntk_factor, const float ext_factor, const float theta_scale, - const float theta_ntk_scale, const float p0, const int p_delta_rows, - const rope_corr_factors corr_factors, + const rope_corr_dims corr_dims, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(num_blocks_x, nrows, 1); - rope_f32<<>>(x, dst, ncols, freq_scale, ntk_factor, ext_factor, theta_scale, - theta_ntk_scale, p0, p_delta_rows, corr_factors); + rope_f32<<>>( + x, dst, ncols, freq_scale, ext_factor, theta_scale, p0, p_delta_rows, corr_dims + ); } static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { @@ -5000,11 +4992,13 @@ inline void ggml_cuda_op_rope( const int n_ctx = ((int32_t *) dst->op_params)[3]; // RoPE alteration for extended context - float freq_base, freq_scale, ntk_factor, ext_factor; - memcpy(&freq_base, (int32_t *) 
dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); @@ -5018,12 +5012,13 @@ inline void ggml_cuda_op_rope( rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); } else { const float p0 = (mode & 1) == 0 ? n_past : 0; - const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - rope_corr_factors corr_factors; - ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors.v); + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v); - rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale, - theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main); + rope_f32_cuda( + src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, theta_scale, p0, ne01, corr_dims, + cudaStream_main + ); } (void) src1; diff --git a/ggml-metal.m b/ggml-metal.m index 372d3e696..0e8b0f9a9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1035,11 +1035,13 @@ void ggml_metal_graph_compute( const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - float freq_base, freq_scale, ntk_factor, ext_factor; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1063,10 +1065,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; - [encoder setBytes:&ntk_factor length:sizeof(float) atIndex:23]; - [encoder setBytes:&ext_factor length:sizeof(float) atIndex:24]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + [encoder setBytes:&ext_factor length:sizeof(float) atIndex:23]; + 
[encoder setBytes:&attn_factor length:sizeof(float) atIndex:24]; + [encoder setBytes:&beta_fast length:sizeof(float) atIndex:25]; + [encoder setBytes:&beta_slow length:sizeof(float) atIndex:26]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 347fd17ac..f5a98d092 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -597,53 +597,41 @@ kernel void kernel_alibi_f32( } } -static float rope_ntkv2_ramp(const float low, const float high, const int i0) { +static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / min(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); } -// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static float rope_ntkv2( - const float theta_base, - const float theta_linear, - const float theta_ntk, - const float corr_factors[4], - const int64_t i0, - const float ntk_factor, - const float ext_factor) { - float ramp_mix; - float theta; +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - ramp_mix = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; - theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - - ramp_mix = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * ext_factor; - theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; - return theta; + // Get n-d magnitude scaling corrected for interpolation + if (freq_scale > 1.0f) + mscale *= 1.0f + 0.1f * logf(freq_scale); + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; } -// Interpolation constants found experimentally for LLaMA (might not be totally optimal though) -// Do not change unless there is a good reason for doing so! 
-constant float BETA_0 = 1.75f; -constant float BETA_1 = 1.25f; -constant float GAMMA_0 = 16.0f; -constant float GAMMA_1 = 2.0f; - constant float max_pos_emb = 2048; // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) { +static float rope_yarn_corr_factor(const int n_dims, const float n_rot, const float base) { return n_dims * log(max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log(base)); } -static void rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) { - // start and end correction factors - factors[0] = max(0.0f, floor(rope_ntkv2_corr_factor(n_dims, BETA_0, freq_base))); - factors[1] = min(n_dims - 1.0f, ceil(rope_ntkv2_corr_factor(n_dims, BETA_1, freq_base))); - factors[2] = max(0.0f, floor(rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base))); - factors[3] = min(n_dims - 1.0f, ceil(rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base))); +static void rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, beta_slow, freq_base))); } kernel void kernel_rope( @@ -670,33 +658,29 @@ kernel void kernel_rope( constant int & mode, constant float & freq_base, constant float & freq_scale, - constant float & ntk_factor, constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, uint3 tpig[[thread_position_in_grid]]) { const int64_t i3 = tpig[2]; const int64_t i2 = tpig[1]; const int64_t i1 = tpig[0]; const float theta_scale = pow(freq_base, -2.0f/n_dims); - const float theta_ntk_scale = pow(freq_base * pow(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - float corr_factors[4]; - rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims); float theta_base = (mode & 1) == 0 ? n_past + i2 : i2; - float theta_ntk = theta_base; const bool is_neox = mode & 2; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float theta_linear = freq_scale * theta_base; - const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, - i0, ntk_factor, ext_factor); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); theta_base *= theta_scale; - theta_ntk *= theta_ntk_scale; device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); diff --git a/ggml.c b/ggml.c index 8a5739181..8bf7f52e2 100644 --- a/ggml.c +++ b/ggml.c @@ -6712,8 +6712,10 @@ static struct ggml_tensor * ggml_rope_impl( int n_ctx, float freq_base, float freq_scale, - float ntk_factor, float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6724,11 +6726,13 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; - memcpy(params + 4, &freq_base, sizeof(float)); - memcpy(params + 5, &freq_scale, sizeof(float)); - memcpy(params + 6, &ntk_factor, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); + int32_t params[10] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + memcpy(params + 6, &ext_factor, sizeof(float)); + memcpy(params + 7, &attn_factor, sizeof(float)); + memcpy(params + 8, &beta_fast, sizeof(float)); + memcpy(params + 9, &beta_slow, sizeof(float)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -6745,7 +6749,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 0.0f, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false); } struct ggml_tensor * ggml_rope_inplace( @@ -6755,7 +6759,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 0.0f, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true); } struct ggml_tensor * ggml_rope_custom( @@ -6767,9 +6771,13 @@ struct ggml_tensor * ggml_rope_custom( int n_ctx, float freq_base, float freq_scale, - float ntk_factor, - float ext_factor) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ntk_factor, ext_factor, false); + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, false + ); } struct ggml_tensor * ggml_rope_custom_inplace( @@ -6781,9 +6789,13 @@ struct ggml_tensor * ggml_rope_custom_inplace( int n_ctx, float freq_base, float freq_scale, - float ntk_factor, - float ext_factor) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ntk_factor, ext_factor, true); + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, true + ); } // ggml_rope_back @@ -12012,52 +12024,40 @@ static void ggml_compute_forward_clamp( // ggml_compute_forward_rope -static inline float rope_ntkv2_ramp(const float low, const float high, const int i0) { +static inline float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / MIN(0.001f, high - low); return 1 - MIN(1, MAX(0, y)); } -// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
-static float rope_ntkv2( - const float theta_base, - const float theta_linear, - const float theta_ntk, - const float corr_factors[4], - const int64_t i0, - const float ntk_factor, - const float ext_factor) { - float ramp_mix; - float theta; +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + float theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - ramp_mix = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; - theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - - ramp_mix = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * ext_factor; - theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; - return theta; + // Get n-d magnitude scaling corrected for interpolation + if (freq_scale > 1.0f) + mscale *= 1.0f + 0.1f * logf(freq_scale); + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; } // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) { +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(const int n_dims, const float n_rot, const float base) { static const float max_pos_emb = 2048; return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } -void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) { - // Interpolation constants found experimentally for LLaMA (might not be totally optimal though) - // Do not change unless there is a good reason for doing so! 
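// [Illustrative aside, not part of the patch] Standalone arithmetic check of the
// correction-dim formula derived in the comment above:
//   corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2*pi)) / (2 * log(base))
// The defaults below (head dim 128, base 10000, beta_fast 32, beta_slow 1) are the
// ones used in this patch; MAX_POS_EMB mirrors the constant above.
#include <math.h>
#include <stdio.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#define MAX_POS_EMB 2048.0f

static float corr_dim(int n_dims, float n_rot, float base) {
    return n_dims * logf(MAX_POS_EMB / (n_rot * 2.0f * (float)M_PI)) / (2.0f * logf(base));
}

int main(void) {
    const int   n_dims    = 128;      // LLaMA head dimension
    const float freq_base = 10000.0f;
    const float beta_fast = 32.0f;    // default from this patch
    const float beta_slow = 1.0f;     // default from this patch

    const float start = fmaxf(0.0f,          floorf(corr_dim(n_dims, beta_fast, freq_base)));
    const float end   = fminf(n_dims - 1.0f, ceilf (corr_dim(n_dims, beta_slow, freq_base)));

    // prints approximately: corr dims = [16, 41]
    printf("corr dims = [%g, %g]\n", start, end);
    return 0;
}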
- static const float BETA_0 = 1.75f; - static const float BETA_1 = 1.25f; - static const float GAMMA_0 = 16.0f; - static const float GAMMA_1 = 2.0f; - - // start and end correction factors - factors[0] = MAX(0, floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0, freq_base))); - factors[1] = MIN(n_dims - 1, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_1, freq_base))); - factors[2] = MAX(0, floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base))); - factors[3] = MIN(n_dims - 1, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base))); +void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) { + // start and end correction dims + dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, beta_fast, freq_base))); + dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, beta_slow, freq_base))); } static void ggml_compute_forward_rope_f32( @@ -12069,19 +12069,18 @@ static void ggml_compute_forward_rope_f32( return; } - float freq_base; - float freq_scale; - float ntk_factor; - float ext_factor; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float)); assert(n_past >= 0); @@ -12111,9 +12110,8 @@ static void ggml_compute_forward_rope_f32( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - float corr_factors[4]; - ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12126,7 +12124,6 @@ static void ggml_compute_forward_rope_f32( if (ir > ir1) break; float theta_base = (float)p; - float theta_ntk = theta_base; if (is_glm) { theta_base = MIN(p, n_ctx - 2); @@ -12155,14 +12152,12 @@ static void ggml_compute_forward_rope_f32( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float theta_linear = freq_scale * theta_base; - const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, - i0, ntk_factor, ext_factor); - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); theta_base *= theta_scale; - theta_ntk *= theta_ntk_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12211,19 
+12206,18 @@ static void ggml_compute_forward_rope_f16( return; } - float freq_base; - float freq_scale; - float ntk_factor; - float ext_factor; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float)); assert(n_past >= 0); @@ -12253,9 +12247,8 @@ static void ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - float corr_factors[4]; - ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12268,7 +12261,6 @@ static void ggml_compute_forward_rope_f16( if (ir > ir1) break; float theta_base = (float)p; - float theta_ntk = theta_base; if (is_glm) { theta_base = MIN(p, n_ctx - 2); @@ -12297,14 +12289,12 @@ static void ggml_compute_forward_rope_f16( } } if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float theta_linear = freq_scale * theta_base; - const float theta = rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, - i0, ntk_factor, ext_factor); - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); theta_base *= theta_scale; - theta_ntk *= theta_ntk_scale; const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); diff --git a/ggml.h b/ggml.h index c2c6b7b1d..06e1cbf7b 100644 --- a/ggml.h +++ b/ggml.h @@ -1195,8 +1195,10 @@ extern "C" { int n_ctx, float freq_base, float freq_scale, - float ntk_factor, - float ext_factor); + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_custom_inplace( @@ -1208,11 +1210,13 @@ extern "C" { int n_ctx, float freq_base, float freq_scale, - float ntk_factor, - float ext_factor); + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); - // compute correction factors for NTKv2 RoPE scaling - void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]); + // compute correction dims for YaRN RoPE scaling + void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, 
i.e compute dx from dy // a - dy diff --git a/llama.cpp b/llama.cpp index 945215de8..3ea065082 100644 --- a/llama.cpp +++ b/llama.cpp @@ -194,10 +194,12 @@ struct llama_hparams { float f_ffn_mult = 1.0f; float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; - float rope_freq_base = 10000.0f; - float rope_freq_scale = 1.0f; - float rope_ntk_factor = 0.0f; - float rope_ext_factor = 0.0f; + float rope_freq_base = 10000.0f; + float rope_freq_scale = 1.0f; + float rope_ext_factor = 0.0f; + float rope_attn_factor = 1.0f; + float rope_beta_fast = 0.0f; + float rope_beta_slow = 0.0f; enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; @@ -900,8 +902,10 @@ struct llama_context_params llama_context_default_params() { /*.tensor_split =*/ nullptr, /*.rope_freq_base =*/ 10000.0f, /*.rope_freq_scale =*/ 1.0f, - /*.rope_ntk_factor =*/ 0.0f, /*.rope_ext_factor =*/ 0.0f, + /*.rope_attn_factor =*/ 1.0f, + /*.rope_beta_fast =*/ 32.0f, + /*.rope_beta_slow =*/ 1.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.low_vram =*/ false, @@ -1036,8 +1040,10 @@ static void llama_model_load_internal( const bool mul_mat_q, float rope_freq_base, float rope_freq_scale, - float rope_ntk_factor, float rope_ext_factor, + float rope_attn_factor, + float rope_beta_fast, + float rope_beta_slow, bool low_vram, ggml_type memory_type, bool use_mmap, @@ -1087,10 +1093,12 @@ static void llama_model_load_internal( hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model } - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; - hparams.rope_ntk_factor = rope_ntk_factor; - hparams.rope_ext_factor = rope_ext_factor; + hparams.rope_freq_base = rope_freq_base; + hparams.rope_freq_scale = rope_freq_scale; + hparams.rope_ext_factor = rope_ext_factor; + hparams.rope_attn_factor = rope_attn_factor; + hparams.rope_beta_fast = rope_beta_fast; + hparams.rope_beta_slow = rope_beta_slow; } // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 @@ -1100,24 +1108,26 @@ static void llama_model_load_internal( //const uint32_t n_ff = 28672; { - fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); - fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); - fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); - fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); - fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); - fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); - fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. 
n_embd_head, n_head_dim - fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); - fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); - fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - fprintf(stderr, "%s: ntk_factor = %g\n", __func__, hparams.rope_ntk_factor); - fprintf(stderr, "%s: ext_factor = %g\n", __func__, hparams.rope_ext_factor); - fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); + fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); + fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); + fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); + fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + fprintf(stderr, "%s: ext_factor = %g\n", __func__, hparams.rope_ext_factor); + fprintf(stderr, "%s: attn_factor = %g\n", __func__, hparams.rope_attn_factor); + fprintf(stderr, "%s: beta_fast = %g\n", __func__, hparams.rope_beta_fast); + fprintf(stderr, "%s: beta_slow = %g\n", __func__, hparams.rope_beta_slow); + fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); + fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { @@ -1384,8 +1394,10 @@ static bool llama_model_load( const bool mul_mat_q, float rope_freq_base, float rope_freq_scale, - float rope_ntk_factor, float rope_ext_factor, + float rope_attn_factor, + float rope_beta_fast, + float rope_beta_slow, bool low_vram, ggml_type memory_type, bool use_mmap, @@ -1394,10 +1406,11 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, - tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, rope_ntk_factor, - rope_ext_factor, low_vram, memory_type, use_mmap, use_mlock, vocab_only, - progress_callback, progress_callback_user_data); + llama_model_load_internal( + fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, + rope_freq_base, rope_freq_scale, rope_ext_factor, rope_attn_factor, rope_beta_fast, rope_beta_slow, + low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data + ); return true; } catch (const std::exception & err) { fprintf(stderr, "error loading model: %s\n", err.what()); @@ -1433,10 +1446,12 @@ static struct 
ggml_cgraph * llama_build_graph( LLAMA_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; - const float ntk_factor = hparams.rope_ntk_factor; - const float ext_factor = hparams.rope_ext_factor; + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float ext_factor = hparams.rope_ext_factor; + const float attn_factor = hparams.rope_attn_factor; + const float beta_fast = hparams.rope_beta_fast; + const float beta_slow = hparams.rope_beta_slow; const float rms_norm_eps = hparams.f_rms_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -1567,14 +1582,16 @@ static struct ggml_cgraph * llama_build_graph( ggml_set_name(tmpq, "tmpq"); struct ggml_tensor * Kcur = ggml_rope_custom_inplace( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, - freq_scale, ntk_factor, ext_factor); + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); struct ggml_tensor * Qcur = ggml_rope_custom_inplace( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, - freq_scale, ntk_factor, ext_factor); + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -3216,11 +3233,13 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, - params.rms_norm_eps, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q, - params.rope_freq_base, params.rope_freq_scale, params.rope_ntk_factor, params.rope_ext_factor, - params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, - params.progress_callback, params.progress_callback_user_data)) { + if (!llama_model_load( + path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, + params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, + params.rope_freq_scale, params.rope_ext_factor, params.rope_attn_factor, params.rope_beta_fast, + params.rope_beta_slow, params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + params.progress_callback, params.progress_callback_user_data + )) { delete model; fprintf(stderr, "%s: failed to load model\n", __func__); return nullptr; diff --git a/llama.h b/llama.h index 25bb3952a..66c78c761 100644 --- a/llama.h +++ b/llama.h @@ -98,10 +98,12 @@ extern "C" { const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) // ref: https://github.com/ggerganov/llama.cpp/pull/2054 - float rope_freq_base; // RoPE base frequency - float rope_freq_scale; // RoPE frequency scaling factor - float rope_ntk_factor; // RoPE NTK mix factor - float rope_ext_factor; // RoPE extrapolation mix factor + float rope_freq_base; // RoPE base frequency + float rope_freq_scale; // RoPE frequency scaling factor + float rope_ext_factor; // RoPE extrapolation mix factor + float rope_attn_factor; // RoPE magnitude scaling factor + float rope_beta_fast; // RoPE low 
correction dim
+    float rope_beta_slow;   // RoPE high correction dim
 
     // called with a progress value between 0 and 1, pass NULL to disable
     llama_progress_callback progress_callback;
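As an illustration of the new API surface (and not part of the patch), the sketch below fills the new llama_context_params fields added in llama.h above to run a 2048-context LLaMA model at 8192 tokens with YaRN. The helper name yarn_params_8k and the specific values are assumptions chosen for the example, not tuned recommendations; the same values map onto the new command-line options (--rope-ext-factor, --rope-attn-factor, --rope-beta-fast, --rope-beta-slow) added in examples/common.cpp and examples/server/server.cpp.

#include "llama.h"

// Hypothetical helper: configure YaRN context extension (example values only).
struct llama_context_params yarn_params_8k(void) {
    struct llama_context_params p = llama_context_default_params();
    p.n_ctx            = 8192;
    p.rope_freq_scale  = 0.25f;  // interpolate positions by the 4x extension factor
    p.rope_ext_factor  = 1.0f;   // enable the YaRN ramp mix (0.0f keeps plain interpolation)
    p.rope_attn_factor = 1.0f;   // magnitude scaling, left at its default
    p.rope_beta_fast   = 32.0f;  // low correction dim (default)
    p.rope_beta_slow   = 1.0f;   // high correction dim (default)
    return p;
}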