server: add repeat penalty sigmoid

commit 3722c729b8
parent ea9c32be71

7 changed files with 128 additions and 4 deletions
@@ -117,6 +117,7 @@ struct gpt_sampler_params {
     float   penalty_repeat  = 1.00f; // 1.0 = disabled
     float   penalty_freq    = 0.00f; // 0.0 = disabled
     float   penalty_present = 0.00f; // 0.0 = disabled
+    float   penalty_repeat_sigmoid_growth = 0.00f; // 0.0 = disabled
     int32_t mirostat        = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau    = 5.00f; // target entropy
     float   mirostat_eta    = 0.10f; // learning rate
@@ -168,6 +168,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                 params.penalty_repeat,
                 params.penalty_freq,
                 params.penalty_present,
+                params.penalty_repeat_sigmoid_growth,
                 params.penalize_nl,
                 params.ignore_eos));

@@ -350,6 +350,8 @@ node index.js
 
 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
 
+`repeat_penalty_sigmoid_growth`: Apply a sigmoid to `repeat_penalty` across the `repeat_last_n` range. A value of `1` gives an approximately linear ramp in penalty from `1` to `repeat_penalty`. Values greater than `1` increase the difference in the resulting penalty between the first and second halves of the penalty range. Values between `0` and `1` change the resulting penalty more slowly in the middle of the range. Negative values vary the penalty the same way, but from `repeat_penalty` down to `1`. Default: `0.0`, which is disabled.
+
 `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
 
 `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
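Editor's note: for intuition about the shape this produces, here is a minimal standalone C++ sketch (not part of the commit) of the non-mirrored branch (|growth| >= 1). The growth, penalty, and window-size values are hypothetical, and the normalization mirrors the patch's calc_norm_coeff/apply_norm_coeff shown further down:

    // sketch: how `repeat_penalty_sigmoid_growth` shapes the per-position penalty
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // raw sigmoid centered on x = 0.5, as in the patch's calc_sigmoid
    static float calc_sigmoid(float x, float growth) {
        return 1.0f / (1.0f + std::exp((-x + 0.5f) * growth));
    }

    int main() {
        const float growth         = 4.0f; // hypothetical (|growth| >= 1 branch)
        const float repeat_penalty = 1.5f; // hypothetical
        const int   repeat_last_n  = 8;    // hypothetical window size

        // normalize so the coefficient spans exactly [0, 1] across the window
        const float y0     = calc_sigmoid(0.0f, growth);
        const float y1     = calc_sigmoid(1.0f, growth);
        const float y_min  = std::min(y0, y1);
        const float y_diff = std::max(y0, y1) - y_min;

        // pos == repeat_last_n is the most recent token and gets the full penalty
        for (int pos = 1; pos <= repeat_last_n; ++pos) {
            const float x       = (float) pos / repeat_last_n;
            const float coeff   = (calc_sigmoid(x, growth) - y_min) / y_diff;
            const float penalty = 1.0f + (repeat_penalty - 1.0f) * coeff;
            std::printf("pos %2d: penalty %.3f\n", pos, penalty);
        }
        return 0;
    }

With these values the penalty ramps from about 1.04 for the oldest token in the window up to the full 1.5 for the most recent one, with most of the change concentrated around the middle of the range.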
@@ -898,6 +898,7 @@ struct server_context {
         slot.sparams.penalty_repeat  = json_value(data, "repeat_penalty",    default_sparams.penalty_repeat);
         slot.sparams.penalty_freq    = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
         slot.sparams.penalty_present = json_value(data, "presence_penalty",  default_sparams.penalty_present);
+        slot.sparams.penalty_repeat_sigmoid_growth = json_value(data, "repeat_penalty_sigmoid_growth", default_sparams.penalty_repeat_sigmoid_growth);
         slot.sparams.mirostat        = json_value(data, "mirostat",          default_sparams.mirostat);
         slot.sparams.mirostat_tau    = json_value(data, "mirostat_tau",      default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta    = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
@@ -1239,6 +1240,7 @@ struct server_context {
             {"repeat_penalty",                slot.sparams.penalty_repeat},
             {"presence_penalty",              slot.sparams.penalty_present},
             {"frequency_penalty",             slot.sparams.penalty_freq},
+            {"repeat_penalty_sigmoid_growth", slot.sparams.penalty_repeat_sigmoid_growth},
             {"mirostat",                      slot.sparams.mirostat},
             {"mirostat_tau",                  slot.sparams.mirostat_tau},
             {"mirostat_eta",                  slot.sparams.mirostat_eta},
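Editor's note: as an illustration of the wire format (not from the commit; the prompt and values are placeholders), a completion request would pass the new field alongside the existing repetition settings:

    POST /completion
    {
        "prompt": "...",
        "repeat_last_n": 64,
        "repeat_penalty": 1.5,
        "repeat_penalty_sigmoid_growth": 4.0
    }

Omitted fields fall back to default_sparams via json_value, so existing clients are unaffected.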
@@ -1124,6 +1124,7 @@ extern "C" {
                            float   penalty_repeat,                // 1.0 = disabled
                            float   penalty_freq,                  // 0.0 = disabled
                            float   penalty_present,               // 0.0 = disabled
+                           float   penalty_repeat_sigmoid_growth, // 0.0 = disabled
                            bool    penalize_nl,                   // consider newlines as a repeatable token
                            bool    ignore_eos);                   // ignore the end-of-sequence token

@@ -1381,6 +1381,7 @@ struct llama_sampler_penalties {
     const float penalty_repeat;
     const float penalty_freq;
     const float penalty_present;
+    const float penalty_repeat_sigmoid_growth;
 
     const bool  penalize_nl;
     const bool  ignore_eos;
@@ -1450,6 +1451,115 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
         }
     }
 
+    struct sigmoid {
+    protected:
+        bool  enabled;
+        float growth;
+        bool  use_mirrored;
+        const ring_buffer<llama_token> & last_tokens;
+        size_t last_tokens_size;
+        size_t penalty_last_n;
+        float token_x;
+        float y_min  = 0;
+        float y_diff = 0;
+
+        // raw sigmoid centered on x = 0.5
+        inline float calc_sigmoid(float x) {
+            float y = 1 / (1 + expf((-x + 0.5f) * growth));
+            return y;
+        }
+
+        inline float calc_sigmoid_inv_growth(float x) {
+            float y = 1 / (1 + expf((-x + 0.5f) / growth));
+            return y;
+        }
+
+        // sigmoid mirrored by y=x
+        inline float calc_mirrored_sigmoid(float x) {
+            if ((x == 0 && growth > 0) || (x >= 1 && growth < 0)) {
+                return 0;
+            }
+            if ((x == 0 && growth < 0) || (x >= 1 && growth > 0)) {
+                return 1;
+            }
+            // the actual formula: y = 0.5 - log((1 - x) / x) / growth
+            // but we invert the growth to transform the initial (0;1) range to the (1;+inf) range
+            float inv_growth = 1 / growth;
+            float y = 0.5f - logf((1 - x) / x) / inv_growth;
+            return y;
+        }
+
+        inline float calc_norm_coeff(float x) {
+            if (use_mirrored) {
+                // map x into the y-range of the non-mirrored sigmoid before inverting
+                float norm_x = y_min + x * y_diff;
+                float y = calc_mirrored_sigmoid(norm_x);
+                return y;
+            }
+            float y = calc_sigmoid(x);
+            float norm_y = (y - y_min) / y_diff;
+            return norm_y;
+        }
+
+        // scale the distance of the penalty from 1 by the normalized coefficient
+        static inline float apply_norm_coeff(float coeff, float penalty) {
+            float initial_diff = penalty - 1;
+            float result_diff  = initial_diff * coeff;
+            return 1 + result_diff;
+        }
+
+    public:
+        explicit sigmoid(
+                float growth,
+                const ring_buffer<llama_token> & last_tokens,
+                size_t penalty_last_n
+        ) :
+            enabled(growth != 0),
+            growth(growth),
+            use_mirrored(fabsf(growth) < 1),
+            last_tokens(last_tokens),
+            last_tokens_size(std::min(penalty_last_n, last_tokens.size())),
+            penalty_last_n(penalty_last_n),
+            token_x(1 / (float) penalty_last_n) {
+            if (!enabled) {
+                return;
+            }
+            float y1;
+            float y2;
+            if (use_mirrored) {
+                y1 = calc_sigmoid_inv_growth(0);
+                y2 = calc_sigmoid_inv_growth(1);
+            } else {
+                y1 = calc_sigmoid(0);
+                y2 = calc_sigmoid(1);
+            }
+            y_min = std::min(y1, y2);
+            float y_max = std::max(y1, y2);
+            y_diff = y_max - y_min;
+        }
+
+        inline float apply(float penalty, llama_token token) {
+            if (!enabled) {
+                return penalty;
+            }
+            // the position (from the end) within the penalty tokens array
+            size_t token_rindex = 0;
+            while (token_rindex < last_tokens_size) {
+                if (last_tokens.rat(token_rindex) == token) {
+                    break; // a penalized token is always present in the buffer, so this always breaks
+                }
+                token_rindex++;
+            }
+            // the position within the penalty range;
+            // it is 1-indexed, so the last token in the range corresponds to x=1
+            size_t token_pos = penalty_last_n - token_rindex;
+            float x = token_x * token_pos;
+            float coeff = calc_norm_coeff(x);
+            float resulting_penalty = apply_norm_coeff(coeff, penalty);
+            return resulting_penalty;
+        }
+    };
+
+    sigmoid penalty_repeat_sigmoid(ctx->penalty_repeat_sigmoid_growth, ctx->prev, ctx->penalty_last_n);
+
     // Create a frequency map to count occurrences of each token in last_tokens
     // TODO: optimize this by maintaining the token count in the sampler context
     using llama_token_cnt = std::unordered_map<llama_token, int>;
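Editor's note on the mirrored branch: calc_mirrored_sigmoid is the inverse function of calc_sigmoid_inv_growth — solving y = 1 / (1 + exp((0.5 - x) / growth)) for x gives x = 0.5 - growth * log((1 - y) / y). Remapping x into [y_min, y_max] before inverting pins the coefficient to 0 at one end of the window and 1 at the other, and because the forward sigmoid is steep for growth values below 1 (its slope scales with 1/growth), its inverse is flat in the middle — the slow mid-range change the README describes for 0 < |growth| < 1. (The remapping line above reads y_min + x * y_diff; the commit as posted had (x + y_min) * y_diff, which does not land the endpoints on [y_min, y_max] and can produce coefficients outside [0, 1], so it is corrected here.)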
@@ -1461,7 +1571,8 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
 
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
+        const auto token = cur_p->data[i].id;
+        const auto token_iter = token_count.find(token);
         if (token_iter == token_count.end()) {
             continue;
         }
@@ -1470,11 +1581,14 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
 
         // The academic publication that described this technique actually just divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is a common fix for the problem: multiply by the penalty instead of dividing.
+        float applied_penalty_repeat;
         if (cur_p->data[i].logit <= 0) {
-            cur_p->data[i].logit *= ctx->penalty_repeat;
+            applied_penalty_repeat = ctx->penalty_repeat;
         } else {
-            cur_p->data[i].logit /= ctx->penalty_repeat;
+            applied_penalty_repeat = 1 / ctx->penalty_repeat;
         }
+        applied_penalty_repeat = penalty_repeat_sigmoid.apply(applied_penalty_repeat, token);
+        cur_p->data[i].logit *= applied_penalty_repeat;
 
         cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
     }
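Editor's worked example (not from the commit): with repeat_penalty = 1.5, a penalized token with a positive logit starts from applied_penalty_repeat = 1/1.5 ≈ 0.667. If it sits mid-window and the sigmoid yields a coefficient of 0.5, apply_norm_coeff returns 1 + (0.667 - 1) * 0.5 ≈ 0.833, so its logit is multiplied by 0.833 rather than the full 0.667; the most recent occurrence (coefficient 1) would get 0.667 exactly.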
@@ -1502,6 +1616,7 @@ static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_s
             ctx->penalty_repeat,
             ctx->penalty_freq,
             ctx->penalty_present,
+            ctx->penalty_repeat_sigmoid_growth,
             ctx->penalize_nl,
             ctx->ignore_eos);

@@ -1536,6 +1651,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_repeat,
         float penalty_freq,
         float penalty_present,
+        float penalty_repeat_sigmoid_growth,
         bool  penalize_nl,
         bool  ignore_eos) {
     if (linefeed_id == LLAMA_TOKEN_NULL) {
@@ -1558,6 +1674,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         /* .penalty_repeat  = */ penalty_repeat,
         /* .penalty_freq    = */ penalty_freq,
         /* .penalty_present = */ penalty_present,
+        /* .penalty_repeat_sigmoid_growth = */ penalty_repeat_sigmoid_growth,
         /* .penalize_nl     = */ penalize_nl,
         /* .ignore_eos      = */ ignore_eos,
         /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
@@ -149,7 +149,7 @@ static void test_penalties(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
 
-    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, 0.0, false, false);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);