diff --git a/examples/common.cpp b/examples/common.cpp
index 21f4a0357..957022d0c 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -194,6 +194,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-ntk-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_ntk_factor = std::stof(argv[i]);
+        } else if (arg == "--rope-ext-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_ext_factor = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -566,6 +578,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    fprintf(stdout, "  --rope-ntk-factor N   RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
+    fprintf(stdout, "  --rope-ext-factor N   RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stdout, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -657,6 +671,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.embedding       = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
+    lparams.rope_ntk_factor = params.rope_ntk_factor;
+    lparams.rope_ext_factor = params.rope_ext_factor;
 
     return lparams;
 }
diff --git a/examples/common.h b/examples/common.h
index 375bc0a3d..677676ad1 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -32,6 +32,8 @@ struct gpt_params {
     float   rms_norm_eps    = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
     float   rope_freq_base  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale = 1.0f;     // RoPE frequency scaling factor
+    float   rope_ntk_factor = 0.0f;     // RoPE NTK mix factor
+    float   rope_ext_factor = 0.0f;     // RoPE extrapolation mix factor
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6f7a66da1..49d2dd050 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -612,6 +612,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --rope-freq-base N        RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N       RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    fprintf(stdout, "  --rope-ntk-factor N       RoPE NTK mix factor (default: %.1f)\n", params.rope_ntk_factor);
+    fprintf(stdout, "  --rope-ext-factor N       RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
     fprintf(stdout, "  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stdout, "                            not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -764,6 +766,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.rope_freq_scale = std::stof(argv[i]);
         }
+        else if (arg == "--rope-ntk-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_ntk_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--rope-ext-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_ext_factor = std::stof(argv[i]);
+        }
         else if (arg == "--memory-f32" || arg == "--memory_f32")
         {
             params.memory_f16 = false;
diff --git a/ggml.c b/ggml.c
index beb7f4641..8c5f7ac26 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1,5 +1,6 @@
 #define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
 
 #include "ggml.h"
 
@@ -6711,6 +6712,8 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 ntk_factor,
+        float                 ext_factor,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6721,9 +6724,11 @@
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &ntk_factor, sizeof(float));
+    memcpy(params + 7, &ext_factor, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -6740,7 +6745,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 0.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6750,7 +6755,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 0.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -6761,8 +6766,10 @@ struct ggml_tensor * ggml_rope_custom(
         int                   mode,
         int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+        float                 freq_scale,
+        float                 ntk_factor,
+        float                 ext_factor) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ntk_factor, ext_factor, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -6773,8 +6780,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   mode,
         int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
+        float                 freq_scale,
+        float                 ntk_factor,
+        float                 ext_factor) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ntk_factor, ext_factor, true);
 }
 
 // ggml_rope_back
@@ -12003,6 +12012,52 @@ static void ggml_compute_forward_clamp(
 
 // ggml_compute_forward_rope
 
+// Solving `n_rot = max_pos_emb / (2pi * base^(2x / n_dims))` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+#define NTKV2_MAX_POS_EMB 2048
+#define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2)
+
+static inline float rope_ntkv2_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
+
+// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static float rope_ntkv2(
+    const float theta_base,
+    const float theta_ntk,
+    const float dims_over_base,
+    const float freq_scale,
+    const int64_t i0,
+    const float ntk_factor,
+    const float ext_factor,
+    const int n_dims) {
+    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
+    // Do not change unless there is a good reason for doing so!
+    static const float BETA_0  = 1.75f;
+    static const float BETA_1  = 1.25f;
+    static const float GAMMA_0 = 16.0f;
+    static const float GAMMA_1 = 2.0f;
+
+    static const float low_1p  = NTKV2_CORRECTION_FACTOR(BETA_0);
+    static const float high_1p = NTKV2_CORRECTION_FACTOR(BETA_1);
+    static const float low_2p  = NTKV2_CORRECTION_FACTOR(GAMMA_0);
+    static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1);
+
+    // start and end correction factors
+    const float low_1  = MAX(0, floorf(low_1p  * dims_over_base));
+    const float high_1 = MIN(n_dims - 1, ceilf(high_1p * dims_over_base));
+    const float low_2  = MAX(0, floorf(low_2p  * dims_over_base));
+    const float high_2 = MIN(n_dims - 1, ceilf(high_2p * dims_over_base));
+
+    const float theta_linear = freq_scale * theta_base;
+    const float ramp_mix = rope_ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
+    const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
+    const float ramp_final = rope_ntkv2_ramp(low_2, high_2, i0) * ext_factor;
+    return theta_mix * (1 - ramp_final) + theta_base * ramp_final;
+}
+
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -12014,6 +12069,8 @@ static void ggml_compute_forward_rope_f32(
 
     float freq_base;
     float freq_scale;
+    float ntk_factor;
+    float ext_factor;
 
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -12021,6 +12078,8 @@
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12050,6 +12109,8 @@
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
+    const float dims_over_base = n_dims / logf(freq_base);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12061,18 +12122,19 @@
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;
 
-            float theta = freq_scale * (float)p;
+            float theta_base = (float)p;
+            float theta_ntk = theta_base;
 
             if (is_glm) {
-                theta = MIN(p, n_ctx - 2);
+                theta_base = MIN(p, n_ctx - 2);
                 float block_theta = MAX(p - (n_ctx - 2), 0);
                 for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                    const float cos_theta = cosf(theta_base);
+                    const float sin_theta = sinf(theta_base);
                     const float cos_block_theta = cosf(block_theta);
                     const float sin_block_theta = sinf(block_theta);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
                     block_theta *= theta_scale;
 
                     const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -12090,10 +12152,13 @@
                 }
             } else if (!is_neox) {
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                    const float theta = rope_ntkv2(theta_base, theta_ntk, dims_over_base,
+                            freq_scale, i0, ntk_factor, ext_factor, n_dims);
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
+                    theta_ntk *= theta_ntk_scale;
 
                     const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -12107,12 +12172,13 @@
             } else {
                 // TODO: this is probably wrong, but I can't figure it out ..
                 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                theta_base *= freq_scale;
                 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                     for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const int64_t i0 = ib*n_dims + ic/2;
 
@@ -12143,6 +12209,8 @@ static void ggml_compute_forward_rope_f16(
 
     float freq_base;
     float freq_scale;
+    float ntk_factor;
+    float ext_factor;
 
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -12150,6 +12218,8 @@
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&ntk_factor, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12179,6 +12249,8 @@
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
+    const float dims_over_base = n_dims / logf(freq_base);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12190,18 +12262,19 @@
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;
 
-            float theta = freq_scale * (float)p;
+            float theta_base = (float)p;
+            float theta_ntk = theta_base;
 
             if (is_glm) {
-                theta = MIN(p, n_ctx - 2);
+                theta_base = MIN(p, n_ctx - 2);
                 float block_theta = MAX(p - (n_ctx - 2), 0);
                 for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                    const float cos_theta = cosf(theta_base);
+                    const float sin_theta = sinf(theta_base);
                     const float cos_block_theta = cosf(block_theta);
                     const float sin_block_theta = sinf(block_theta);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
                     block_theta *= theta_scale;
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -12219,10 +12292,13 @@
                 }
             } if (!is_neox) {
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                    const float theta = rope_ntkv2(theta_base, theta_ntk, dims_over_base,
+                            freq_scale, i0, ntk_factor, ext_factor, n_dims);
                    const float cos_theta = cosf(theta);
                    const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
+                    theta_ntk *= theta_ntk_scale;
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -12236,12 +12312,13 @@
             } else {
                 // TODO: this is probably wrong, but I can't figure it out ..
                 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                theta_base *= freq_scale;
                 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                     for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const int64_t i0 = ib*n_dims + ic/2;
 
@@ -12335,14 +12412,14 @@ static void ggml_compute_forward_rope_back_f32(
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;
 
-            float theta = (float)p;
+            float theta_base = (float)p;
 
             if (!is_neox) {
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                    const float cos_theta = cosf(theta_base);
+                    const float sin_theta = sinf(theta_base);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
 
                     const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -12356,10 +12433,10 @@
             } else {
                 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                     for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const int64_t i0 = ib*n_dims + ic/2;
 
@@ -12431,14 +12508,14 @@
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;
 
-            float theta = (float)p;
+            float theta_base = (float)p;
 
             if (!is_neox) {
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                    const float cos_theta = cosf(theta_base);
+                    const float sin_theta = sinf(theta_base);
 
-                    theta *= theta_scale;
+                    theta_base *= theta_scale;
 
                     const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -12452,10 +12529,10 @@
             } else {
                 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                     for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const int64_t i0 = ib*n_dims + ic/2;
 
diff --git a/ggml.h b/ggml.h
index bdbd12800..459d217df 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1194,7 +1194,9 @@ extern "C" {
             int                   mode,
             int                   n_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ntk_factor,
+            float                 ext_factor);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@@ -1205,7 +1207,9 @@ extern "C" {
             int                   mode,
             int                   n_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ntk_factor,
+            float                 ext_factor);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
diff --git a/llama.cpp b/llama.cpp
index 39aefd499..945215de8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -196,6 +196,8 @@ struct llama_hparams {
 
     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
+    float rope_ntk_factor = 0.0f;
+    float rope_ext_factor = 0.0f;
 
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
@@ -898,6 +900,8 @@ struct llama_context_params llama_context_default_params() {
        /*.tensor_split                =*/ nullptr,
        /*.rope_freq_base              =*/ 10000.0f,
        /*.rope_freq_scale             =*/ 1.0f,
+        /*.rope_ntk_factor             =*/ 0.0f,
+        /*.rope_ext_factor             =*/ 0.0f,
        /*.progress_callback           =*/ nullptr,
        /*.progress_callback_user_data =*/ nullptr,
        /*.low_vram                    =*/ false,
@@ -1032,6 +1036,8 @@ static void llama_model_load_internal(
         const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
+        float rope_ntk_factor,
+        float rope_ext_factor,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1083,6 +1089,8 @@ static void llama_model_load_internal(
 
         hparams.rope_freq_base  = rope_freq_base;
         hparams.rope_freq_scale = rope_freq_scale;
+        hparams.rope_ntk_factor = rope_ntk_factor;
+        hparams.rope_ext_factor = rope_ext_factor;
     }
 
     // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
@@ -1106,6 +1114,8 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_ff       = %u\n", __func__, n_ff);
         fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
         fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
+        fprintf(stderr, "%s: ntk_factor = %g\n",   __func__, hparams.rope_ntk_factor);
+        fprintf(stderr, "%s: ext_factor = %g\n",   __func__, hparams.rope_ext_factor);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
         fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
     }
@@ -1374,6 +1384,8 @@ static bool llama_model_load(
         const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
+        float rope_ntk_factor,
+        float rope_ext_factor,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1382,9 +1394,10 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers,
-                main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
-                use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu,
+                                  tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, rope_ntk_factor,
+                                  rope_ext_factor, low_vram, memory_type, use_mmap, use_mlock, vocab_only,
+                                  progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
         fprintf(stderr, "error loading model: %s\n", err.what());
@@ -1422,6 +1435,8 @@ static struct ggml_cgraph * llama_build_graph(
 
     const float freq_base    = hparams.rope_freq_base;
     const float freq_scale   = hparams.rope_freq_scale;
+    const float ntk_factor   = hparams.rope_ntk_factor;
+    const float ext_factor   = hparams.rope_ext_factor;
     const float rms_norm_eps = hparams.f_rms_norm_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
@@ -1551,11 +1566,15 @@ static struct ggml_cgraph * llama_build_graph(
                 offload_func_kq(tmpq);
                 ggml_set_name(tmpq, "tmpq");
 
-                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
+                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
+                    freq_scale, ntk_factor, ext_factor);
                 offload_func_kq(Kcur);
                 ggml_set_name(Kcur, "Kcur");
 
-                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
+                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
+                    freq_scale, ntk_factor, ext_factor);
                 offload_func_kq(Qcur);
                 ggml_set_name(Qcur, "Qcur");
 
@@ -3197,10 +3216,11 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
-                memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
-                params.progress_callback_user_data)) {
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa,
+                params.rms_norm_eps, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q,
+                params.rope_freq_base, params.rope_freq_scale, params.rope_ntk_factor, params.rope_ext_factor,
+                params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
+                params.progress_callback, params.progress_callback_user_data)) {
         delete model;
         fprintf(stderr, "%s: failed to load model\n", __func__);
         return nullptr;
diff --git a/llama.h b/llama.h
index fa1977f2d..25bb3952a 100644
--- a/llama.h
+++ b/llama.h
@@ -100,6 +100,8 @@ extern "C" {
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;  // RoPE base frequency
         float rope_freq_scale; // RoPE frequency scaling factor
+        float rope_ntk_factor; // RoPE NTK mix factor
+        float rope_ext_factor; // RoPE extrapolation mix factor
 
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
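The core of the change is the rope_ntkv2() helper added to ggml.c: per dimension it blends three theta schedules (linear interpolation via freq_scale, the NTK-aware schedule via theta_ntk, and plain extrapolation via theta_base), with ramp boundaries derived from NTKV2_CORRECTION_FACTOR. The standalone C sketch below reproduces that mixing outside of ggml so it can be printed and inspected; the head size (128), position, freq_scale and mix factors are illustrative values picked here, and the helpers correction_dim() and ramp() are local to the sketch, not names from the patch.

/* Standalone sketch of the NTKv2 theta mixing performed by rope_ntkv2() above (illustrative values). */
#include <math.h>
#include <stdio.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#define MAX_POS_EMB 2048   /* original training context, as in NTKV2_MAX_POS_EMB */
#define N_DIMS      128    /* assumed head size, for illustration only */

/* dimension index at which a frequency completes n_rot rotations over MAX_POS_EMB positions */
static float correction_dim(float n_rot, float freq_base) {
    return N_DIMS * logf(MAX_POS_EMB / (n_rot * 2.0f * (float)M_PI)) / (2.0f * logf(freq_base));
}

/* ramp from 1 down to 0 as i0/2 moves from low to high, clamped at both ends */
static float ramp(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / fmaxf(0.001f, high - low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

int main(void) {
    const float freq_base  = 10000.0f;
    const float freq_scale = 0.5f;   /* e.g. targeting 2x the original context by interpolation */
    const float ntk_factor = 1.0f;
    const float ext_factor = 1.0f;
    const int   p          = 3000;   /* an out-of-range position, for illustration */

    /* per-dimension frequency decay, plus the NTK-aware variant used for theta_ntk */
    const float theta_scale     = powf(freq_base, -2.0f / N_DIMS);
    const float theta_ntk_scale = powf(freq_base * powf(freq_scale, N_DIMS / (N_DIMS - 2.0f)), -2.0f / N_DIMS);

    /* ramp boundaries, matching BETA_0/BETA_1 and GAMMA_0/GAMMA_1 in the patch */
    const float low_1  = fmaxf(0.0f,       floorf(correction_dim(1.75f, freq_base)));
    const float high_1 = fminf(N_DIMS - 1, ceilf (correction_dim(1.25f, freq_base)));
    const float low_2  = fmaxf(0.0f,       floorf(correction_dim(16.0f, freq_base)));
    const float high_2 = fminf(N_DIMS - 1, ceilf (correction_dim(2.0f,  freq_base)));

    float theta_base = (float)p;
    float theta_ntk  = (float)p;
    for (int i0 = 0; i0 < N_DIMS; i0 += 2) {
        /* same mixing as rope_ntkv2(): linear -> NTK mix, then extrapolation mix */
        const float theta_linear = freq_scale * theta_base;
        const float ramp_mix     = ramp(low_1, high_1, i0) * ntk_factor;
        const float theta_mix    = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
        const float ramp_final   = ramp(low_2, high_2, i0) * ext_factor;
        const float theta        = theta_mix * (1 - ramp_final) + theta_base * ramp_final;

        if (i0 % 32 == 0) {
            printf("i0=%3d  linear=%12.4f  ntk=%12.4f  mixed=%12.4f\n", i0, theta_linear, theta_ntk, theta);
        }
        theta_base *= theta_scale;
        theta_ntk  *= theta_ntk_scale;
    }
    return 0;
}

With both factors at 1, the fastest dimensions (above roughly GAMMA_0 rotations) keep the unscaled extrapolated theta, the middle band follows the NTK-adjusted base, and the slowest dimensions fall back to plain interpolation; the two --rope-*-factor flags simply scale how strongly each ramp is applied.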
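On the graph side, ggml_rope_impl() now packs eight 32-bit slots into op_params: four int32 values followed by the four floats copied in bit-for-bit, which the compute kernels recover with the mirror-image memcpy. A minimal round-trip of that layout, relying (as ggml does) on float and int32_t both being 4 bytes:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    /* slots 0-3: n_past, n_dims, mode, n_ctx; slots 4-7: the four floats, stored via memcpy */
    int32_t params[8] = { 0, 128, 0, 2048 };

    const float freq_base = 10000.0f, freq_scale = 1.0f, ntk_factor = 0.0f, ext_factor = 0.0f;
    memcpy(params + 4, &freq_base,  sizeof(float));
    memcpy(params + 5, &freq_scale, sizeof(float));
    memcpy(params + 6, &ntk_factor, sizeof(float));
    memcpy(params + 7, &ext_factor, sizeof(float));

    /* the rope kernels read the values back the same way */
    float out_freq_base, out_ext_factor;
    memcpy(&out_freq_base,  params + 4, sizeof(float));
    memcpy(&out_ext_factor, params + 7, sizeof(float));
    printf("n_dims=%d freq_base=%.1f ext_factor=%.1f\n", (int)params[1], out_freq_base, out_ext_factor);
    return 0;
}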
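At the API level the two factors ride along with the existing RoPE fields in llama_context_params and default to 0.0f, which leaves both mixes disabled; a caller that wants the NTKv2 behavior sets them explicitly. A small usage sketch against the public C header, with illustrative values:

#include "llama.h"

struct llama_context_params ntkv2_params(void) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.rope_freq_scale = 0.5f;  /* interpolate for 2x context */
    cparams.rope_ntk_factor = 1.0f;  /* full NTK mix */
    cparams.rope_ext_factor = 1.0f;  /* full extrapolation mix for the fastest dimensions */
    return cparams;
}

This mirrors what the new --rope-ntk-factor and --rope-ext-factor command-line flags do through gpt_params in common.cpp and server.cpp.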