From 3722c729b8121256766405dc263dfbb565577bdb Mon Sep 17 00:00:00 2001 From: ZXED Date: Sun, 15 Sep 2024 15:00:52 +0300 Subject: [PATCH 1/3] server: add repeat penalty sigmoid --- common/common.h | 1 + common/sampling.cpp | 1 + examples/server/README.md | 2 + examples/server/server.cpp | 2 + include/llama.h | 1 + src/llama-sampling.cpp | 123 ++++++++++++++++++++++++++++++++++++- tests/test-sampling.cpp | 2 +- 7 files changed, 128 insertions(+), 4 deletions(-) diff --git a/common/common.h b/common/common.h index cb87c4479..f8c1e6126 100644 --- a/common/common.h +++ b/common/common.h @@ -117,6 +117,7 @@ struct gpt_sampler_params { float penalty_repeat = 1.00f; // 1.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled + float penalty_repeat_sigmoid_growth = 0.00f; // 0.0 = disabled int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate diff --git a/common/sampling.cpp b/common/sampling.cpp index 3dc7f1120..cfc2652c2 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -168,6 +168,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st params.penalty_repeat, params.penalty_freq, params.penalty_present, + params.penalty_repeat_sigmoid_growth, params.penalize_nl, params.ignore_eos)); diff --git a/examples/server/README.md b/examples/server/README.md index dfca07f98..ec45c5237 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -350,6 +350,8 @@ node index.js `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. + `repeat_penalty_sigmoid_growth`: Apply the sigmoid function to `repeat_penalty` within `repeat_last_n` range. The value of `1` means linear change in penalty from 1 to `repeat_penalty`. Higher values > 1 increase the difference in the resulting penalty between the first and the second half of the penalty range. 
Lower values < 1 change the resulting penalty more slowly in the middle of the range. Negative values change the penalty in the same way, but from `repeat_penalty` to 1. Default: `0.0`, which is disabled. + `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 61ff09bb2..0c600b9fa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -898,6 +898,7 @@ struct server_context { slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.penalty_repeat_sigmoid_growth = json_value(data, "repeat_penalty_sigmoid_growth", default_sparams.penalty_repeat_sigmoid_growth); slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); @@ -1239,6 +1240,7 @@ struct server_context { {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, + {"repeat_penalty_sigmoid_growth", slot.sparams.penalty_repeat_sigmoid_growth}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, diff --git a/include/llama.h b/include/llama.h index 132937a07..cbedb378b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1124,6 +1124,7 @@ extern "C" { float penalty_repeat, // 1.0 = disabled float penalty_freq, 
// 0.0 = disabled float penalty_present, // 0.0 = disabled + float penalty_repeat_sigmoid_growth, // 0.0 = disabled bool penalize_nl, // consider newlines as a repeatable token bool ignore_eos); // ignore the end-of-sequence token diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e255a8fc4..08ee2110b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1381,6 +1381,7 @@ struct llama_sampler_penalties { const float penalty_repeat; const float penalty_freq; const float penalty_present; + const float penalty_repeat_sigmoid_growth; const bool penalize_nl; const bool ignore_eos; @@ -1450,6 +1451,115 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok } } + struct sigmoid { + protected: + bool enabled; + float growth; + bool use_mirrored; + const ring_buffer & last_tokens; + size_t last_tokens_size; + size_t penalty_last_n; + float token_x; + float y_min = 0; + float y_diff = 0; + + inline float calc_sigmoid(float x) { + float y = 1 / (1 + exp((-x + 0.5) * growth)); + return y; + } + + inline float calc_sigmoid_inv_growth(float x) { + float y = 1 / (1 + exp((-x + 0.5) / growth)); + return y; + } + + // sigmoid mirrored by y=x + inline float calc_mirrored_sigmoid(float x) { + if ((x == 0 && growth > 0) || (x >= 1 && growth < 0)) { + return 0; + } + if ((x == 0 && growth < 0) || (x >= 1 && growth > 0)) { + return 1; + } + // the actual formula: y = 0.5 - log((1 - x) / x) / growth + // but we invert the growth to transform the initial (0;1) range to the (1;+inf) range + float inv_growth = 1 / growth; + float y = 0.5 - log((1 - x) / x) / inv_growth; + return y; + } + + inline float calc_norm_coeff(float x) { + if (use_mirrored) { + float norm_x = (x + y_min) * y_diff; // normalize x within a range of the non-mirrored sigmoid's y + float y = calc_mirrored_sigmoid(norm_x); + return y; + } + + float y = calc_sigmoid(x); + float norm_y = (y - y_min) / y_diff; + return norm_y; + } + + static inline float 
apply_norm_coeff(float coeff, float penalty) { + float initial_diff = penalty - 1; + float result_diff = initial_diff * coeff; + return 1 + result_diff; + } + + public: + explicit sigmoid( + float growth, + const ring_buffer & last_tokens, + size_t penalty_last_n + ) : + enabled(growth != 0 && penalty_last_n > 0), + growth(growth), + use_mirrored(fabsf(growth) < 1), + last_tokens(last_tokens), + last_tokens_size(std::min(penalty_last_n, last_tokens.size())), + penalty_last_n(penalty_last_n), + token_x(1 / (float)penalty_last_n) { + if (!enabled) { + return; + } + float y1; + float y2; + if (use_mirrored) { + y1 = calc_sigmoid_inv_growth(0); + y2 = calc_sigmoid_inv_growth(1); + } else { + y1 = calc_sigmoid(0); + y2 = calc_sigmoid(1); + } + y_min = std::min(y1, y2); + float y_max = std::max(y1, y2); + y_diff = y_max - y_min; + } + + inline float apply(float penalty, llama_token token) { + if (!enabled) { + return penalty; + } + // the position (from the end) within the penalty tokens array + size_t token_rindex = 0; + while (token_rindex < last_tokens_size) { + if (last_tokens.rat(token_rindex) == token) { + break; // every penalized token is inside the window, so this always breaks + } + token_rindex++; + } + // the position within the penalty range, + // it's 1-indexed, so the last token in the range will correspond to x=1 + size_t token_pos = penalty_last_n - token_rindex; + float x = token_x * token_pos; + float coeff = calc_norm_coeff(x); + float resulting_penalty = apply_norm_coeff(coeff, penalty); + return resulting_penalty; + } + }; + + sigmoid penalty_repeat_sigmoid(ctx->penalty_repeat_sigmoid_growth, ctx->prev, ctx->penalty_last_n); + // Create a frequency map to count occurrences of each token in last_tokens // TODO: optimize this by maintaining the token count in the sampler context using llama_token_cnt = std::unordered_map; @@ -1461,7 +1571,8 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok // Apply frequency and presence penalties to the cur_p for (size_t 
i = 0; i < cur_p->size; ++i) { - const auto token_iter = token_count.find(cur_p->data[i].id); + const auto token = cur_p->data[i].id; + const auto token_iter = token_count.find(token); if (token_iter == token_count.end()) { continue; } @@ -1470,11 +1581,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + float applied_penalty_repeat; if (cur_p->data[i].logit <= 0) { - cur_p->data[i].logit *= ctx->penalty_repeat; + applied_penalty_repeat = ctx->penalty_repeat; } else { - cur_p->data[i].logit /= ctx->penalty_repeat; + applied_penalty_repeat = 1 / ctx->penalty_repeat; } + applied_penalty_repeat = penalty_repeat_sigmoid.apply(applied_penalty_repeat, token); + cur_p->data[i].logit *= applied_penalty_repeat; cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present; } @@ -1502,6 +1616,7 @@ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_s ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present, + ctx->penalty_repeat_sigmoid_growth, ctx->penalize_nl, ctx->ignore_eos); @@ -1536,6 +1651,7 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_repeat, float penalty_freq, float penalty_present, + float penalty_repeat_sigmoid_growth, bool penalize_nl, bool ignore_eos) { if (linefeed_id == LLAMA_TOKEN_NULL) { @@ -1558,6 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties( /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, + /* .penalty_repeat_sigmoid_growth */ penalty_repeat_sigmoid_growth, /* .penalize_nl = */ penalize_nl, /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer(penalty_last_n), 
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 6e021c4c7..ef458bc64 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -149,7 +149,7 @@ static void test_penalties( llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); + auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false); for (size_t i = 0; i < last_tokens.size(); i++) { llama_sampler_accept(sampler, last_tokens[i]); From c795d8b82bf2d6d61e49b3179ec9562641902a53 Mon Sep 17 00:00:00 2001 From: ZXED Date: Wed, 25 Sep 2024 19:57:52 +0300 Subject: [PATCH 2/3] add tests --- src/llama-sampling.cpp | 2 +- tests/test-sampling.cpp | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 08ee2110b..fd2e3dbef 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1674,7 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties( /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, - /* .penalty_repeat_sigmoid_growth */ penalty_repeat_sigmoid_growth, + /* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth, /* .penalize_nl = */ penalize_nl, /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer(penalty_last_n), diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index ef458bc64..c4574bf59 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -134,7 +134,7 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float 
alpha_presence + const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_repeat_sigmoid_growth ) { GGML_ASSERT(probs.size() == expected_probs.size()); @@ -149,7 +149,7 @@ static void test_penalties( llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false); + auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_repeat_sigmoid_growth, false, false); for (size_t i = 0; i < last_tokens.size(); i++) { llama_sampler_accept(sampler, last_tokens[i]); @@ -316,13 +316,17 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 
0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f); + + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 0.0f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); From 8b0d3ab5ab63ceee3a9b24550b80b8f92f4c741c Mon Sep 17 00:00:00 2001 From: ZXED Date: Wed, 25 Sep 2024 21:00:46 +0300 Subject: [PATCH 3/3] fix formatting --- tests/test-sampling.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index c4574bf59..e06f52fc6 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -320,9 +320,9 @@ int main(void) { test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f); - test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 
0.140414, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 0.0f, 0.0f, 10.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.263353f, 0.263353f, 0.201890f, 0.153630f, 0.117775f}, 1.5f, 0.0f, 0.0f, 1.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.290188f, 0.246533f, 0.182452f, 0.140414f, 0.140414f}, 0.5f, 0.0f, 0.0f, -0.5f); test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 0.0f); test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 0.0f);